1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Testing traversal control in ray query extension
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayQueryTraversalControlTests.hpp"
25 
26 #include <array>
27 
28 #include "vkDefs.hpp"
29 
30 #include "vktTestCase.hpp"
31 #include "vktTestGroupUtil.hpp"
32 #include "vkCmdUtil.hpp"
33 #include "vkObjUtil.hpp"
34 #include "vkBuilderUtil.hpp"
35 #include "vkBarrierUtil.hpp"
36 #include "vkBufferWithMemory.hpp"
37 #include "vkImageWithMemory.hpp"
38 #include "vkTypeUtil.hpp"
39 #include "vkImageUtil.hpp"
40 #include "deRandom.hpp"
41 #include "tcuTexture.hpp"
42 #include "tcuTextureUtil.hpp"
43 #include "tcuTestLog.hpp"
44 #include "tcuImageCompare.hpp"
45 
46 #include "vkRayTracingUtil.hpp"
47 
48 namespace vkt
49 {
50 namespace RayQuery
51 {
52 namespace
53 {
54 using namespace vk;
55 using namespace vkt;
56 
57 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
58                                               VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
59                                               VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
60 
// Kind of pipeline that hosts the shader issuing the ray queries.
enum ShaderSourcePipeline
{
    SSP_GRAPHICS_PIPELINE    = 0,
    SSP_COMPUTE_PIPELINE     = 1,
    SSP_RAY_TRACING_PIPELINE = 2
};
67 
// Shader stage whose shader issues the ray queries under test.
enum ShaderSourceType
{
    // Graphics pipeline stages.
    SST_VERTEX_SHADER                 = 0,
    SST_TESSELATION_CONTROL_SHADER    = 1,
    SST_TESSELATION_EVALUATION_SHADER = 2,
    SST_GEOMETRY_SHADER               = 3,
    SST_FRAGMENT_SHADER               = 4,
    // Compute pipeline stage.
    SST_COMPUTE_SHADER = 5,
    // Ray tracing pipeline stages.
    SST_RAY_GENERATION_SHADER = 6,
    SST_INTERSECTION_SHADER   = 7,
    SST_ANY_HIT_SHADER        = 8,
    SST_CLOSEST_HIT_SHADER    = 9,
    SST_MISS_SHADER           = 10,
    SST_CALLABLE_SHADER       = 11,
};
83 
// Which ray query shader variant is used. The values double as indices into
// the per-geometry list of shader name suffixes ("rq_gen_*" first, then
// "rq_skip_*"), so they must stay 0 and 1.
enum ShaderTestType
{
    STT_GENERATE_INTERSECTION = 0,
    STT_SKIP_INTERSECTION     = STT_GENERATE_INTERSECTION + 1,
};
89 
// Geometry kind the ray queries are traced against (triangles or AABBs).
// Used directly as an index into a two-element vector, so values must stay 0 and 1.
enum BottomTestType
{
    BTT_TRIANGLES = 0,
    BTT_AABBS     = 1
};
95 
// Default extent of the test grid along each axis; presumably used to fill
// TestParams::width / TestParams::height (confirm at the test-case factory).
constexpr uint32_t TEST_WIDTH  = 8;
constexpr uint32_t TEST_HEIGHT = 8;

// Forward declaration: the TestConfiguration interface takes TestParams by
// reference before the struct itself is defined.
struct TestParams;
100 
// Abstract interface for one pipeline flavour of the ray query traversal
// control tests. An implementation creates its Vulkan objects, records the
// work into a command buffer and checks the resulting data.
class TestConfiguration
{
public:
    virtual ~TestConfiguration();
    // Creates the configuration's Vulkan objects for testParams (see
    // GraphicsConfiguration::initConfiguration for the graphics variant).
    virtual void initConfiguration(Context &context, TestParams &testParams) = 0;
    // Records the test work into commandBuffer. The acceleration structure and
    // result image descriptor data are supplied by the caller.
    virtual void fillCommandBuffer(
        Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
        const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
        const VkDescriptorImageInfo &resultImageInfo)                                                  = 0;
    // Checks the contents of resultBuffer against a reference; presumably
    // returns true when they match (confirm in the implementations).
    virtual bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) = 0;
    // Format of the result image and the size of one texel of that format.
    virtual VkFormat getResultImageFormat()                                                            = 0;
    virtual size_t getResultImageFormatSize()                                                          = 0;
    // Value the result image is cleared to before the test work runs (presumably).
    virtual VkClearValue getClearValue()                                                               = 0;
};
115 
~TestConfiguration()116 TestConfiguration::~TestConfiguration()
117 {
118 }
119 
// Parameters describing a single test case. Keep the field order stable: the
// struct is an aggregate and may be brace-initialized positionally.
struct TestParams
{
    uint32_t width;                            // framebuffer / viewport width
    uint32_t height;                           // framebuffer / viewport height
    ShaderSourceType shaderSourceType;         // stage whose shader uses the ray query variant
    ShaderSourcePipeline shaderSourcePipeline; // pipeline kind hosting that stage
    ShaderTestType shaderTestType;             // selects the "rq_gen_*" or "rq_skip_*" shader variant
    BottomTestType bottomType;                 // triangle or AABB geometry variant
};
129 
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)130 uint32_t getShaderGroupHandleSize(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
131 {
132     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
133 
134     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
135     return rayTracingPropertiesKHR->getShaderGroupHandleSize();
136 }
137 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)138 uint32_t getShaderGroupBaseAlignment(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
139 {
140     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
141 
142     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
143     return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
144 }
145 
makeImageCreateInfo(uint32_t width,uint32_t height,uint32_t depth,VkFormat format)146 VkImageCreateInfo makeImageCreateInfo(uint32_t width, uint32_t height, uint32_t depth, VkFormat format)
147 {
148     const VkImageCreateInfo imageCreateInfo = {
149         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
150         DE_NULL,                             // const void* pNext;
151         (VkImageCreateFlags)0u,              // VkImageCreateFlags flags;
152         VK_IMAGE_TYPE_3D,                    // VkImageType imageType;
153         format,                              // VkFormat format;
154         makeExtent3D(width, height, depth),  // VkExtent3D extent;
155         1u,                                  // uint32_t mipLevels;
156         1u,                                  // uint32_t arrayLayers;
157         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
158         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
159         VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
160             VK_IMAGE_USAGE_TRANSFER_DST_BIT, // VkImageUsageFlags usage;
161         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
162         0u,                                  // uint32_t queueFamilyIndexCount;
163         DE_NULL,                             // const uint32_t* pQueueFamilyIndices;
164         VK_IMAGE_LAYOUT_UNDEFINED            // VkImageLayout initialLayout;
165     };
166 
167     return imageCreateInfo;
168 }
169 
registerShaderModule(const DeviceInterface & vkd,const VkDevice device,Context & context,std::vector<de::SharedPtr<Move<VkShaderModule>>> & shaderModules,std::vector<VkPipelineShaderStageCreateInfo> & shaderCreateInfos,VkShaderStageFlagBits stage,const std::string & externalNamePart,const std::string & internalNamePart)170 bool registerShaderModule(const DeviceInterface &vkd, const VkDevice device, Context &context,
171                           std::vector<de::SharedPtr<Move<VkShaderModule>>> &shaderModules,
172                           std::vector<VkPipelineShaderStageCreateInfo> &shaderCreateInfos, VkShaderStageFlagBits stage,
173                           const std::string &externalNamePart, const std::string &internalNamePart)
174 {
175     char fullShaderName[40];
176     snprintf(fullShaderName, 40, externalNamePart.c_str(), internalNamePart.c_str());
177     std::string fsn = fullShaderName;
178     if (fsn.empty())
179         return false;
180 
181     shaderModules.push_back(
182         makeVkSharedPtr(createShaderModule(vkd, device, context.getBinaryCollection().get(fsn), 0)));
183 
184     shaderCreateInfos.push_back({
185         VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, DE_NULL, (VkPipelineShaderStageCreateFlags)0,
186         stage,                       // stage
187         shaderModules.back()->get(), // shader
188         "main",
189         DE_NULL, // pSpecializationInfo
190     });
191 
192     return true;
193 }
194 
registerShaderModule(const DeviceInterface & vkd,const VkDevice device,Context & context,RayTracingPipeline & rayTracingPipeline,VkShaderStageFlagBits shaderStage,const std::string & externalNamePart,const std::string & internalNamePart,uint32_t groupIndex)195 bool registerShaderModule(const DeviceInterface &vkd, const VkDevice device, Context &context,
196                           RayTracingPipeline &rayTracingPipeline, VkShaderStageFlagBits shaderStage,
197                           const std::string &externalNamePart, const std::string &internalNamePart, uint32_t groupIndex)
198 {
199     char fullShaderName[40];
200     snprintf(fullShaderName, 40, externalNamePart.c_str(), internalNamePart.c_str());
201     std::string fsn = fullShaderName;
202     if (fsn.empty())
203         return false;
204     Move<VkShaderModule> shaderModule = createShaderModule(vkd, device, context.getBinaryCollection().get(fsn), 0);
205     if (*shaderModule == DE_NULL)
206         return false;
207     rayTracingPipeline.addShader(shaderStage, shaderModule, groupIndex);
208     return true;
209 }
210 
// TestConfiguration that issues the ray queries from one of the graphics
// pipeline stages: vertex, tessellation control, tessellation evaluation,
// geometry or fragment.
class GraphicsConfiguration : public TestConfiguration
{
public:
    virtual ~GraphicsConfiguration();
    void initConfiguration(Context &context, TestParams &testParams) override;
    void fillCommandBuffer(
        Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
        const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
        const VkDescriptorImageInfo &resultImageInfo) override;
    bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) override;
    VkFormat getResultImageFormat() override;
    size_t getResultImageFormatSize() override;
    VkClearValue getClearValue() override;

protected:
    // Members are destroyed in reverse declaration order. For example,
    // descriptorSet is released before descriptorPool, and pipeline before
    // pipelineLayout and renderPass. Keep this ordering when adding members.
    Move<VkDescriptorSetLayout> descriptorSetLayout;
    Move<VkDescriptorPool> descriptorPool;
    Move<VkDescriptorSet> descriptorSet;
    Move<VkPipelineLayout> pipelineLayout;
    Move<VkRenderPass> renderPass;
    Move<VkFramebuffer> framebuffer;
    std::vector<de::SharedPtr<Move<VkShaderModule>>> shaderModules; // owners of the pipeline's shader modules
    Move<VkPipeline> pipeline;
    std::vector<tcu::Vec3> vertices; // draw geometry; its size is the draw's vertex count
    Move<VkBuffer> vertexBuffer;
    de::MovePtr<Allocation> vertexAlloc; // host-visible memory bound to vertexBuffer
};
238 
~GraphicsConfiguration()239 GraphicsConfiguration::~GraphicsConfiguration()
240 {
241     shaderModules.clear();
242 }
243 
// Creates everything the graphics variant needs before command recording:
//  - a descriptor set with a storage image (binding 0) and an acceleration
//    structure (binding 1), visible to all graphics stages, plus its layout,
//    pool and pipeline layout;
//  - shader modules for the stages required by testParams.shaderSourceType,
//    where only the selected stage uses the ray query shader variant;
//  - a render pass and framebuffer without any attachments;
//  - the graphics pipeline;
//  - a host-visible vertex buffer filled with the geometry for that stage.
// Throws InternalError for a shader source type that is not a graphics stage.
void GraphicsConfiguration::initConfiguration(Context &context, TestParams &testParams)
{
    const DeviceInterface &vkd      = context.getDeviceInterface();
    const VkDevice device           = context.getDevice();
    const uint32_t queueFamilyIndex = context.getUniversalQueueFamilyIndex();
    Allocator &allocator            = context.getDefaultAllocator();

    descriptorSetLayout =
        DescriptorSetLayoutBuilder()
            .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_SHADER_STAGE_ALL_GRAPHICS)
            .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, VK_SHADER_STAGE_ALL_GRAPHICS)
            .build(vkd, device);
    descriptorPool = DescriptorPoolBuilder()
                         .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
                         .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
                         .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
    descriptorSet  = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
    pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());

    // Ray query shader name suffixes, indexed as [bottomType][shaderTestType].
    // The "gen" variant is pushed first (index 0 == STT_GENERATE_INTERSECTION)
    // and the "skip" variant second (index 1 == STT_SKIP_INTERSECTION).
    std::vector<std::vector<std::string>> rayQueryTestName(2);
    rayQueryTestName[BTT_TRIANGLES].push_back("rq_gen_triangle");
    rayQueryTestName[BTT_AABBS].push_back("rq_gen_aabb");
    rayQueryTestName[BTT_TRIANGLES].push_back("rq_skip_triangle");
    rayQueryTestName[BTT_AABBS].push_back("rq_skip_aabb");

    // Per tested stage: shader name templates for the vert, tesc, tese, geom and
    // frag slots. "%s" is replaced by the ray query suffix above, and an empty
    // template means the stage is not part of the pipeline.
    const std::map<ShaderSourceType, std::vector<std::string>> shaderNames = {
        //idx: 0                1                2                3                4
        //shader: vert, tesc, tese, geom, frag,
        {SST_VERTEX_SHADER,
         {
             "vert_%s",
             "",
             "",
             "",
             "",
         }},
        {SST_TESSELATION_CONTROL_SHADER,
         {
             "vert",
             "tesc_%s",
             "tese",
             "",
             "",
         }},
        {SST_TESSELATION_EVALUATION_SHADER,
         {
             "vert",
             "tesc",
             "tese_%s",
             "",
             "",
         }},
        {SST_GEOMETRY_SHADER,
         {
             "vert",
             "",
             "",
             "geom_%s",
             "",
         }},
        {SST_FRAGMENT_SHADER,
         {
             "vert",
             "",
             "",
             "",
             "frag_%s",
         }},
    };

    auto shaderNameIt = shaderNames.find(testParams.shaderSourceType);
    if (shaderNameIt == end(shaderNames))
        TCU_THROW(InternalError, "Wrong shader source type");

    // tescX / teseX / fragX record whether the optional tessellation and
    // fragment stages were registered. An empty template registers nothing and
    // yields false. These flags decide below which optional pipeline state
    // structures are supplied. The vertex and geometry results are not needed.
    std::vector<VkPipelineShaderStageCreateInfo> shaderCreateInfos;
    bool tescX, teseX, fragX;
    registerShaderModule(vkd, device, context, shaderModules, shaderCreateInfos, VK_SHADER_STAGE_VERTEX_BIT,
                         shaderNameIt->second[0], rayQueryTestName[testParams.bottomType][testParams.shaderTestType]);
    tescX = registerShaderModule(vkd, device, context, shaderModules, shaderCreateInfos,
                                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, shaderNameIt->second[1],
                                 rayQueryTestName[testParams.bottomType][testParams.shaderTestType]);
    teseX = registerShaderModule(vkd, device, context, shaderModules, shaderCreateInfos,
                                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, shaderNameIt->second[2],
                                 rayQueryTestName[testParams.bottomType][testParams.shaderTestType]);
    registerShaderModule(vkd, device, context, shaderModules, shaderCreateInfos, VK_SHADER_STAGE_GEOMETRY_BIT,
                         shaderNameIt->second[3], rayQueryTestName[testParams.bottomType][testParams.shaderTestType]);
    fragX = registerShaderModule(vkd, device, context, shaderModules, shaderCreateInfos, VK_SHADER_STAGE_FRAGMENT_BIT,
                                 shaderNameIt->second[4],
                                 rayQueryTestName[testParams.bottomType][testParams.shaderTestType]);

    // Single subpass with no attachments at all. Shader results are written
    // through the storage image descriptor instead.
    const vk::VkSubpassDescription subpassDesc = {
        (vk::VkSubpassDescriptionFlags)0,
        vk::VK_PIPELINE_BIND_POINT_GRAPHICS, // pipelineBindPoint
        0u,                                  // inputCount
        DE_NULL,                             // pInputAttachments
        0u,                                  // colorCount
        DE_NULL,                             // pColorAttachments
        DE_NULL,                             // pResolveAttachments
        DE_NULL,                             // depthStencilAttachment
        0u,                                  // preserveCount
        DE_NULL,                             // pPreserveAttachments
    };
    const vk::VkRenderPassCreateInfo renderPassParams = {
        vk::VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO, // sType
        DE_NULL,                                       // pNext
        (vk::VkRenderPassCreateFlags)0,
        0u,           // attachmentCount
        DE_NULL,      // pAttachments
        1u,           // subpassCount
        &subpassDesc, // pSubpasses
        0u,           // dependencyCount
        DE_NULL,      // pDependencies
    };

    renderPass = createRenderPass(vkd, device, &renderPassParams);

    // Attachment-less framebuffer; its width/height define the render area extent.
    const vk::VkFramebufferCreateInfo framebufferParams = {
        vk::VK_STRUCTURE_TYPE_FRAMEBUFFER_CREATE_INFO, // sType
        DE_NULL,                                       // pNext
        (vk::VkFramebufferCreateFlags)0,
        *renderPass,       // renderPass
        0u,                // attachmentCount
        DE_NULL,           // pAttachments
        testParams.width,  // width
        testParams.height, // height
        1u,                // layers
    };

    framebuffer = createFramebuffer(vkd, device, &framebufferParams);

    // Corners of a quad inside [1, width-1] x [1, height-1]. For the non-fragment
    // variants rasterization is discarded (see rasterizerDiscardEnable below), so
    // these are presumably consumed by the shaders as grid coordinates rather
    // than clip-space positions. TODO confirm against the shader sources.
    VkPrimitiveTopology testTopology = VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP;
    tcu::Vec3 v0(1.0f, 1.0f, 0.0f);
    tcu::Vec3 v1(float(testParams.width) - 1.0f, 1.0f, 0.0f);
    tcu::Vec3 v2(1.0f, float(testParams.height) - 1.0f, 0.0f);
    tcu::Vec3 v3(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0f);

    switch (testParams.shaderSourceType)
    {
    case SST_TESSELATION_CONTROL_SHADER:
    case SST_TESSELATION_EVALUATION_SHADER:
        // Two triangles (0,1,2) and (1,3,2) submitted as 3-control-point patches.
        testTopology = VK_PRIMITIVE_TOPOLOGY_PATCH_LIST;
        vertices.push_back(v0);
        vertices.push_back(v1);
        vertices.push_back(v2);
        vertices.push_back(v1);
        vertices.push_back(v3);
        vertices.push_back(v2);
        break;
    case SST_VERTEX_SHADER:
    case SST_GEOMETRY_SHADER:
        // The same quad as a 4-vertex triangle strip.
        vertices.push_back(v0);
        vertices.push_back(v1);
        vertices.push_back(v2);
        vertices.push_back(v3);
        break;
    case SST_FRAGMENT_SHADER:
        // Clip-space quad covering the whole viewport, drawn as a triangle strip.
        vertices.push_back(tcu::Vec3(-1.0f, 1.0f, 0.0f));
        vertices.push_back(tcu::Vec3(-1.0f, -1.0f, 0.0f));
        vertices.push_back(tcu::Vec3(1.0f, 1.0f, 0.0f));
        vertices.push_back(tcu::Vec3(1.0f, -1.0f, 0.0f));
        break;
    default:
        TCU_THROW(InternalError, "Wrong shader source type");
    }

    // One tightly packed vec3 position per vertex, at location 0.
    const VkVertexInputBindingDescription vertexInputBindingDescription = {
        0u,                          // uint32_t binding;
        sizeof(tcu::Vec3),           // uint32_t stride;
        VK_VERTEX_INPUT_RATE_VERTEX, // VkVertexInputRate inputRate;
    };

    const VkVertexInputAttributeDescription vertexInputAttributeDescription = {
        0u,                         // uint32_t location;
        0u,                         // uint32_t binding;
        VK_FORMAT_R32G32B32_SFLOAT, // VkFormat format;
        0u,                         // uint32_t offset;
    };

    const VkPipelineVertexInputStateCreateInfo vertexInputStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                   // const void* pNext;
        (VkPipelineVertexInputStateCreateFlags)0,                  // VkPipelineVertexInputStateCreateFlags flags;
        1u,                                                        // uint32_t vertexBindingDescriptionCount;
        &vertexInputBindingDescription,  // const VkVertexInputBindingDescription* pVertexBindingDescriptions;
        1u,                              // uint32_t vertexAttributeDescriptionCount;
        &vertexInputAttributeDescription // const VkVertexInputAttributeDescription* pVertexAttributeDescriptions;
    };

    const VkPipelineInputAssemblyStateCreateInfo inputAssemblyStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                     // const void* pNext;
        (VkPipelineInputAssemblyStateCreateFlags)0,                  // VkPipelineInputAssemblyStateCreateFlags flags;
        testTopology,                                                // VkPrimitiveTopology topology;
        VK_FALSE                                                     // VkBool32 primitiveRestartEnable;
    };

    // Matches the 3-vertex patches built above; only used when a tessellation stage is present.
    const VkPipelineTessellationStateCreateInfo tessellationStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_TESSELLATION_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                   // const void* pNext;
        VkPipelineTessellationStateCreateFlags(0u),                // VkPipelineTessellationStateCreateFlags flags;
        3u                                                         // uint32_t patchControlPoints;
    };

    VkViewport viewport = makeViewport(testParams.width, testParams.height);
    VkRect2D scissor    = makeRect2D(testParams.width, testParams.height);

    const VkPipelineViewportStateCreateInfo viewportStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, // VkStructureType                                    sType
        DE_NULL,                               // const void*                                        pNext
        (VkPipelineViewportStateCreateFlags)0, // VkPipelineViewportStateCreateFlags                flags
        1u,                                    // uint32_t                                            viewportCount
        &viewport,                             // const VkViewport*                                pViewports
        1u,                                    // uint32_t                                            scissorCount
        &scissor                               // const VkRect2D*                                    pScissors
    };

    // Rasterization is discarded unless a fragment shader was registered.
    const VkPipelineRasterizationStateCreateInfo rasterizationStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                    // const void* pNext;
        (VkPipelineRasterizationStateCreateFlags)0,                 // VkPipelineRasterizationStateCreateFlags flags;
        VK_FALSE,                                                   // VkBool32 depthClampEnable;
        fragX ? VK_FALSE : VK_TRUE,                                 // VkBool32 rasterizerDiscardEnable;
        VK_POLYGON_MODE_FILL,                                       // VkPolygonMode polygonMode;
        VK_CULL_MODE_NONE,                                          // VkCullModeFlags cullMode;
        VK_FRONT_FACE_CLOCKWISE,                                    // VkFrontFace frontFace;
        VK_FALSE,                                                   // VkBool32 depthBiasEnable;
        0.0f,                                                       // float depthBiasConstantFactor;
        0.0f,                                                       // float depthBiasClamp;
        0.0f,                                                       // float depthBiasSlopeFactor;
        1.0f                                                        // float lineWidth;
    };

    const VkPipelineMultisampleStateCreateInfo multisampleStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                  // const void* pNext;
        (VkPipelineMultisampleStateCreateFlags)0,                 // VkPipelineMultisampleStateCreateFlags flags;
        VK_SAMPLE_COUNT_1_BIT,                                    // VkSampleCountFlagBits rasterizationSamples;
        VK_FALSE,                                                 // VkBool32 sampleShadingEnable;
        0.0f,                                                     // float minSampleShading;
        DE_NULL,                                                  // const VkSampleMask* pSampleMask;
        VK_FALSE,                                                 // VkBool32 alphaToCoverageEnable;
        VK_FALSE                                                  // VkBool32 alphaToOneEnable;
    };

    // No color attachments exist, so no per-attachment blend state is given.
    const VkPipelineColorBlendStateCreateInfo colorBlendStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                                  // const void* pNext;
        (VkPipelineColorBlendStateCreateFlags)0,                  // VkPipelineColorBlendStateCreateFlags flags;
        false,                                                    // VkBool32 logicOpEnable;
        VK_LOGIC_OP_CLEAR,                                        // VkLogicOp logicOp;
        0,                                                        // uint32_t attachmentCount;
        DE_NULL,                 // const VkPipelineColorBlendAttachmentState* pAttachments;
        {1.0f, 1.0f, 1.0f, 1.0f} // float blendConstants[4];
    };

    // Viewport, multisample and color blend state are only supplied when
    // rasterization is enabled (fragX). Vulkan ignores these pointers when
    // rasterizerDiscardEnable is VK_TRUE. Tessellation state is only supplied
    // when a tessellation stage was registered.
    const VkGraphicsPipelineCreateInfo graphicsPipelineCreateInfo = {
        VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO, // VkStructureType sType;
        DE_NULL,                                         // const void* pNext;
        (VkPipelineCreateFlags)0,                        // VkPipelineCreateFlags flags;
        static_cast<uint32_t>(shaderCreateInfos.size()), // uint32_t stageCount;
        shaderCreateInfos.data(),                        // const VkPipelineShaderStageCreateInfo* pStages;
        &vertexInputStateCreateInfo,   // const VkPipelineVertexInputStateCreateInfo* pVertexInputState;
        &inputAssemblyStateCreateInfo, // const VkPipelineInputAssemblyStateCreateInfo* pInputAssemblyState;
        (tescX || teseX) ? &tessellationStateCreateInfo :
                           DE_NULL,                 // const VkPipelineTessellationStateCreateInfo* pTessellationState;
        fragX ? &viewportStateCreateInfo : DE_NULL, // const VkPipelineViewportStateCreateInfo* pViewportState;
        &rasterizationStateCreateInfo, // const VkPipelineRasterizationStateCreateInfo* pRasterizationState;
        fragX ? &multisampleStateCreateInfo : DE_NULL, // const VkPipelineMultisampleStateCreateInfo* pMultisampleState;
        DE_NULL, // const VkPipelineDepthStencilStateCreateInfo* pDepthStencilState;
        fragX ? &colorBlendStateCreateInfo : DE_NULL, // const VkPipelineColorBlendStateCreateInfo* pColorBlendState;
        DE_NULL,                                      // const VkPipelineDynamicStateCreateInfo* pDynamicState;
        pipelineLayout.get(),                         // VkPipelineLayout layout;
        renderPass.get(),                             // VkRenderPass renderPass;
        0u,                                           // uint32_t subpass;
        DE_NULL,                                      // VkPipeline basePipelineHandle;
        0                                             // int basePipelineIndex;
    };

    pipeline = createGraphicsPipeline(vkd, device, DE_NULL, &graphicsPipelineCreateInfo);

    // queueFamilyIndex is ignored by Vulkan because the sharing mode is exclusive.
    const VkBufferCreateInfo vertexBufferParams = {
        VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,                                 // VkStructureType sType;
        DE_NULL,                                                              // const void* pNext;
        0u,                                                                   // VkBufferCreateFlags flags;
        VkDeviceSize(sizeof(tcu::Vec3) * vertices.size()),                    // VkDeviceSize size;
        VK_BUFFER_USAGE_VERTEX_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, // VkBufferUsageFlags usage;
        VK_SHARING_MODE_EXCLUSIVE,                                            // VkSharingMode sharingMode;
        1u,                                                                   // uint32_t queueFamilyIndexCount;
        &queueFamilyIndex                                                     // const uint32_t* pQueueFamilyIndices;
    };

    vertexBuffer = createBuffer(vkd, device, &vertexBufferParams);
    vertexAlloc =
        allocator.allocate(getBufferMemoryRequirements(vkd, device, *vertexBuffer), MemoryRequirement::HostVisible);
    VK_CHECK(vkd.bindBufferMemory(device, *vertexBuffer, vertexAlloc->getMemory(), vertexAlloc->getOffset()));

    // Upload vertex data
    deMemcpy(vertexAlloc->getHostPtr(), vertices.data(), vertices.size() * sizeof(tcu::Vec3));
    flushAlloc(vkd, device, *vertexAlloc);
}
544 
fillCommandBuffer(Context & context,TestParams & testParams,VkCommandBuffer commandBuffer,const VkWriteDescriptorSetAccelerationStructureKHR & rayQueryAccelerationStructureWriteDescriptorSet,const VkDescriptorImageInfo & resultImageInfo)545 void GraphicsConfiguration::fillCommandBuffer(
546     Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
547     const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
548     const VkDescriptorImageInfo &resultImageInfo)
549 {
550     const DeviceInterface &vkd = context.getDeviceInterface();
551     const VkDevice device      = context.getDevice();
552 
553     DescriptorSetUpdateBuilder()
554         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
555                      VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &resultImageInfo)
556         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
557                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &rayQueryAccelerationStructureWriteDescriptorSet)
558         .update(vkd, device);
559 
560     const VkRenderPassBeginInfo renderPassBeginInfo = {
561         VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO,        // VkStructureType sType;
562         DE_NULL,                                         // const void* pNext;
563         *renderPass,                                     // VkRenderPass renderPass;
564         *framebuffer,                                    // VkFramebuffer framebuffer;
565         makeRect2D(testParams.width, testParams.height), // VkRect2D renderArea;
566         0u,                                              // uint32_t clearValueCount;
567         DE_NULL                                          // const VkClearValue* pClearValues;
568     };
569     VkDeviceSize vertexBufferOffset = 0u;
570 
571     vkd.cmdBeginRenderPass(commandBuffer, &renderPassBeginInfo, VK_SUBPASS_CONTENTS_INLINE);
572     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *pipeline);
573     vkd.cmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *pipelineLayout, 0u, 1u,
574                               &descriptorSet.get(), 0u, DE_NULL);
575     vkd.cmdBindVertexBuffers(commandBuffer, 0, 1, &vertexBuffer.get(), &vertexBufferOffset);
576     vkd.cmdDraw(commandBuffer, uint32_t(vertices.size()), 1, 0, 0);
577     vkd.cmdEndRenderPass(commandBuffer);
578 }
579 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)580 bool GraphicsConfiguration::verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)
581 {
582     // create result image
583     tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
584     tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 2,
585                                              resultBuffer->getAllocation().getHostPtr());
586 
587     // create reference image
588     std::vector<uint32_t> reference(testParams.width * testParams.height * 2);
589     tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 2, reference.data());
590 
591     tcu::UVec4 rqValue0, rqValue1;
592     switch (testParams.shaderTestType)
593     {
594     case STT_GENERATE_INTERSECTION:
595         switch (testParams.bottomType)
596         {
597         case BTT_TRIANGLES:
598             rqValue0 = tcu::UVec4(1, 0, 0, 0);
599             rqValue1 = tcu::UVec4(1, 0, 0, 0);
600             break;
601         case BTT_AABBS:
602             rqValue0 = tcu::UVec4(2, 0, 0, 0);
603             rqValue1 = tcu::UVec4(1, 0, 0, 0);
604             break;
605         default:
606             TCU_THROW(InternalError, "Wrong bottom test type");
607         }
608         break;
609     case STT_SKIP_INTERSECTION:
610         switch (testParams.bottomType)
611         {
612         case BTT_TRIANGLES:
613             rqValue0 = tcu::UVec4(0, 0, 0, 0);
614             rqValue1 = tcu::UVec4(1, 0, 0, 0);
615             break;
616         case BTT_AABBS:
617             rqValue0 = tcu::UVec4(0, 0, 0, 0);
618             rqValue1 = tcu::UVec4(1, 0, 0, 0);
619             break;
620         default:
621             TCU_THROW(InternalError, "Wrong bottom test type");
622         }
623         break;
624     default:
625         TCU_THROW(InternalError, "Wrong shader test type");
626     }
627 
628     std::vector<std::vector<uint32_t>> primitives = {{0, 1, 2}, {1, 3, 2}};
629 
630     tcu::UVec4 clearValue, missValue, hitValue0, hitValue1;
631     hitValue0  = rqValue0;
632     hitValue1  = rqValue1;
633     missValue  = tcu::UVec4(0, 0, 0, 0);
634     clearValue = tcu::UVec4(0xFF, 0, 0, 0);
635 
636     switch (testParams.shaderSourceType)
637     {
638     case SST_VERTEX_SHADER:
639         tcu::clear(referenceAccess, clearValue);
640         for (uint32_t vertexNdx = 0; vertexNdx < 4; ++vertexNdx)
641         {
642             if (vertexNdx == 0)
643             {
644                 referenceAccess.setPixel(hitValue0, vertexNdx, 0, 0);
645                 referenceAccess.setPixel(hitValue1, vertexNdx, 0, 1);
646             }
647             else
648             {
649                 referenceAccess.setPixel(missValue, vertexNdx, 0, 0);
650                 referenceAccess.setPixel(missValue, vertexNdx, 0, 1);
651             }
652         }
653         break;
654     case SST_TESSELATION_EVALUATION_SHADER:
655     case SST_TESSELATION_CONTROL_SHADER:
656     case SST_GEOMETRY_SHADER:
657         tcu::clear(referenceAccess, clearValue);
658         for (uint32_t primitiveNdx = 0; primitiveNdx < primitives.size(); ++primitiveNdx)
659             for (uint32_t vertexNdx = 0; vertexNdx < 3; ++vertexNdx)
660             {
661                 uint32_t vNdx = primitives[primitiveNdx][vertexNdx];
662                 if (vNdx == 0)
663                 {
664                     referenceAccess.setPixel(hitValue0, primitiveNdx, vertexNdx, 0);
665                     referenceAccess.setPixel(hitValue1, primitiveNdx, vertexNdx, 1);
666                 }
667                 else
668                 {
669                     referenceAccess.setPixel(missValue, primitiveNdx, vertexNdx, 0);
670                     referenceAccess.setPixel(missValue, primitiveNdx, vertexNdx, 1);
671                 }
672             }
673         break;
674     case SST_FRAGMENT_SHADER:
675         tcu::clear(referenceAccess, missValue);
676         for (uint32_t y = 1; y < testParams.height - 1; ++y)
677             for (uint32_t x = 1; x < testParams.width - 1; ++x)
678             {
679                 referenceAccess.setPixel(hitValue0, x, y, 0);
680                 referenceAccess.setPixel(hitValue1, x, y, 1);
681             }
682         break;
683     default:
684         TCU_THROW(InternalError, "Wrong shader source type");
685     }
686 
687     // compare result and reference
688     return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess,
689                                     resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
690 }
691 
getResultImageFormat()692 VkFormat GraphicsConfiguration::getResultImageFormat()
693 {
694     return VK_FORMAT_R32_UINT;
695 }
696 
getResultImageFormatSize()697 size_t GraphicsConfiguration::getResultImageFormatSize()
698 {
699     return sizeof(uint32_t);
700 }
701 
getClearValue()702 VkClearValue GraphicsConfiguration::getClearValue()
703 {
704     return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
705 }
706 
// Test configuration that runs the ray query from a compute pipeline.
// Its descriptor set binds a storage image (binding 0) and an acceleration structure
// (binding 1), both visible to the compute stage.
class ComputeConfiguration : public TestConfiguration
{
public:
    virtual ~ComputeConfiguration();
    // Creates the descriptor objects, pipeline layout, shader module and compute pipeline.
    void initConfiguration(Context &context, TestParams &testParams) override;
    // Writes both descriptors, then records pipeline/descriptor binds and a
    // width x height x 1 dispatch into commandBuffer.
    void fillCommandBuffer(
        Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
        const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
        const VkDescriptorImageInfo &resultImageInfo) override;
    // Compares the two-layer result buffer with the expected ray query values.
    bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) override;
    VkFormat getResultImageFormat() override;
    size_t getResultImageFormatSize() override;
    VkClearValue getClearValue() override;

protected:
    // Objects created by initConfiguration, declared in creation order. C++ destroys members
    // in reverse declaration order, so the pipeline is released first and the descriptor set
    // is freed before its pool. Keep this order when adding members.
    Move<VkDescriptorSetLayout> descriptorSetLayout;
    Move<VkDescriptorPool> descriptorPool;
    Move<VkDescriptorSet> descriptorSet;
    Move<VkPipelineLayout> pipelineLayout;
    Move<VkShaderModule> shaderModule;
    Move<VkPipeline> pipeline;
};
729 
~ComputeConfiguration()730 ComputeConfiguration::~ComputeConfiguration()
731 {
732 }
733 
initConfiguration(Context & context,TestParams & testParams)734 void ComputeConfiguration::initConfiguration(Context &context, TestParams &testParams)
735 {
736     const DeviceInterface &vkd = context.getDeviceInterface();
737     const VkDevice device      = context.getDevice();
738 
739     descriptorSetLayout =
740         DescriptorSetLayoutBuilder()
741             .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_SHADER_STAGE_COMPUTE_BIT)
742             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, VK_SHADER_STAGE_COMPUTE_BIT)
743             .build(vkd, device);
744     descriptorPool = DescriptorPoolBuilder()
745                          .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
746                          .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
747                          .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
748     descriptorSet  = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
749     pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
750 
751     std::vector<std::vector<std::string>> rayQueryTestName(2);
752     rayQueryTestName[BTT_TRIANGLES].push_back("comp_rq_gen_triangle");
753     rayQueryTestName[BTT_AABBS].push_back("comp_rq_gen_aabb");
754     rayQueryTestName[BTT_TRIANGLES].push_back("comp_rq_skip_triangle");
755     rayQueryTestName[BTT_AABBS].push_back("comp_rq_skip_aabb");
756 
757     shaderModule = createShaderModule(
758         vkd, device,
759         context.getBinaryCollection().get(rayQueryTestName[testParams.bottomType][testParams.shaderTestType]), 0u);
760     const VkPipelineShaderStageCreateInfo pipelineShaderStageParams = {
761         VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
762         DE_NULL,                                             // const void* pNext;
763         0u,                                                  // VkPipelineShaderStageCreateFlags flags;
764         VK_SHADER_STAGE_COMPUTE_BIT,                         // VkShaderStageFlagBits stage;
765         *shaderModule,                                       // VkShaderModule module;
766         "main",                                              // const char* pName;
767         DE_NULL,                                             // const VkSpecializationInfo* pSpecializationInfo;
768     };
769     const VkComputePipelineCreateInfo pipelineCreateInfo = {
770         VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // VkStructureType sType;
771         DE_NULL,                                        // const void* pNext;
772         0u,                                             // VkPipelineCreateFlags flags;
773         pipelineShaderStageParams,                      // VkPipelineShaderStageCreateInfo stage;
774         *pipelineLayout,                                // VkPipelineLayout layout;
775         DE_NULL,                                        // VkPipeline basePipelineHandle;
776         0,                                              // int32_t basePipelineIndex;
777     };
778 
779     pipeline = createComputePipeline(vkd, device, DE_NULL, &pipelineCreateInfo);
780 }
781 
fillCommandBuffer(Context & context,TestParams & testParams,VkCommandBuffer commandBuffer,const VkWriteDescriptorSetAccelerationStructureKHR & rayQueryAccelerationStructureWriteDescriptorSet,const VkDescriptorImageInfo & resultImageInfo)782 void ComputeConfiguration::fillCommandBuffer(
783     Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
784     const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
785     const VkDescriptorImageInfo &resultImageInfo)
786 {
787     const DeviceInterface &vkd = context.getDeviceInterface();
788     const VkDevice device      = context.getDevice();
789 
790     DescriptorSetUpdateBuilder()
791         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
792                      VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &resultImageInfo)
793         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
794                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &rayQueryAccelerationStructureWriteDescriptorSet)
795         .update(vkd, device);
796 
797     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *pipeline);
798 
799     vkd.cmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *pipelineLayout, 0u, 1u,
800                               &descriptorSet.get(), 0u, DE_NULL);
801 
802     vkd.cmdDispatch(commandBuffer, testParams.width, testParams.height, 1);
803 }
804 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)805 bool ComputeConfiguration::verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)
806 {
807     // create result image
808     tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
809     tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 2,
810                                              resultBuffer->getAllocation().getHostPtr());
811 
812     // create reference image
813     std::vector<uint32_t> reference(testParams.width * testParams.height * 2);
814     tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 2, reference.data());
815 
816     tcu::UVec4 rqValue0, rqValue1;
817     switch (testParams.shaderTestType)
818     {
819     case STT_GENERATE_INTERSECTION:
820         switch (testParams.bottomType)
821         {
822         case BTT_TRIANGLES:
823             rqValue0 = tcu::UVec4(1, 0, 0, 0);
824             rqValue1 = tcu::UVec4(1, 0, 0, 0);
825             break;
826         case BTT_AABBS:
827             rqValue0 = tcu::UVec4(2, 0, 0, 0);
828             rqValue1 = tcu::UVec4(1, 0, 0, 0);
829             break;
830         default:
831             TCU_THROW(InternalError, "Wrong bottom test type");
832         }
833         break;
834     case STT_SKIP_INTERSECTION:
835         switch (testParams.bottomType)
836         {
837         case BTT_TRIANGLES:
838             rqValue0 = tcu::UVec4(0, 0, 0, 0);
839             rqValue1 = tcu::UVec4(1, 0, 0, 0);
840             break;
841         case BTT_AABBS:
842             rqValue0 = tcu::UVec4(0, 0, 0, 0);
843             rqValue1 = tcu::UVec4(1, 0, 0, 0);
844             break;
845         default:
846             TCU_THROW(InternalError, "Wrong bottom test type");
847         }
848         break;
849     default:
850         TCU_THROW(InternalError, "Wrong shader test type");
851     }
852 
853     tcu::UVec4 missValue0, missValue1, hitValue0, hitValue1;
854     hitValue0  = rqValue0;
855     hitValue1  = rqValue1;
856     missValue0 = tcu::UVec4(0, 0, 0, 0);
857     missValue1 = tcu::UVec4(0, 0, 0, 0);
858 
859     tcu::clear(referenceAccess, missValue0);
860     for (uint32_t y = 0; y < testParams.height; ++y)
861         for (uint32_t x = 0; x < testParams.width; ++x)
862             referenceAccess.setPixel(missValue1, x, y, 1);
863 
864     for (uint32_t y = 1; y < testParams.height - 1; ++y)
865         for (uint32_t x = 1; x < testParams.width - 1; ++x)
866         {
867             referenceAccess.setPixel(hitValue0, x, y, 0);
868             referenceAccess.setPixel(hitValue1, x, y, 1);
869         }
870 
871     // compare result and reference
872     return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess,
873                                     resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
874 }
875 
getResultImageFormat()876 VkFormat ComputeConfiguration::getResultImageFormat()
877 {
878     return VK_FORMAT_R32_UINT;
879 }
880 
getResultImageFormatSize()881 size_t ComputeConfiguration::getResultImageFormatSize()
882 {
883     return sizeof(uint32_t);
884 }
885 
getClearValue()886 VkClearValue ComputeConfiguration::getClearValue()
887 {
888     return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
889 }
890 
// Test configuration that runs the ray query from a ray tracing pipeline stage selected by
// testParams.shaderSourceType. Its descriptor set binds a storage image (binding 0) and two
// acceleration structures (bindings 1 and 2), all visible to every ray tracing stage.
class RayTracingConfiguration : public TestConfiguration
{
public:
    virtual ~RayTracingConfiguration();
    // Creates descriptor objects, the pipeline layout, the ray tracing pipeline and its shader binding tables.
    void initConfiguration(Context &context, TestParams &testParams) override;
    // Builds the scene acceleration structures, writes descriptors and records the trace-rays command.
    void fillCommandBuffer(
        Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
        const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
        const VkDescriptorImageInfo &resultImageInfo) override;
    // Compares the two-layer result buffer with the expected values for the selected stage.
    bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) override;
    VkFormat getResultImageFormat() override;
    size_t getResultImageFormatSize() override;
    VkClearValue getClearValue() override;

protected:
    // Members are destroyed in reverse declaration order; keep dependents declared after what they depend on.
    Move<VkDescriptorSetLayout> descriptorSetLayout;
    Move<VkDescriptorPool> descriptorPool;
    Move<VkDescriptorSet> descriptorSet;
    Move<VkPipelineLayout> pipelineLayout;

    // Pipeline builder (shader registration) and the pipeline it created.
    de::MovePtr<RayTracingPipeline> rayTracingPipeline;
    Move<VkPipeline> rtPipeline;

    // One single-handle SBT per shader group; left empty when the group has no shader.
    de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
    de::MovePtr<BufferWithMemory> hitShaderBindingTable;
    de::MovePtr<BufferWithMemory> missShaderBindingTable;
    de::MovePtr<BufferWithMemory> callableShaderBindingTable;

    // Scene traced by the pipeline; built in fillCommandBuffer.
    std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelAccelerationStructures;
    de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
};
922 
~RayTracingConfiguration()923 RayTracingConfiguration::~RayTracingConfiguration()
924 {
925 }
926 
// Creates every object the ray tracing variant needs before command recording:
//  - the descriptor set (binding 0: storage image, bindings 1 and 2: acceleration structures,
//    all visible to ALL_RAY_TRACING_STAGES) and its pipeline layout;
//  - the ray tracing pipeline, with shaders chosen by testParams.shaderSourceType;
//  - one single-handle shader binding table for each shader group that received a shader.
void RayTracingConfiguration::initConfiguration(Context &context, TestParams &testParams)
{
    const InstanceInterface &vki          = context.getInstanceInterface();
    const DeviceInterface &vkd            = context.getDeviceInterface();
    const VkDevice device                 = context.getDevice();
    const VkPhysicalDevice physicalDevice = context.getPhysicalDevice();
    Allocator &allocator                  = context.getDefaultAllocator();

    descriptorSetLayout = DescriptorSetLayoutBuilder()
                              .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
                              .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
                              .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
                              .build(vkd, device);
    descriptorPool = DescriptorPoolBuilder()
                         .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
                         .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
                         .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
                         .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
    descriptorSet  = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
    pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());

    rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();

    // Shader names per source type, one entry per stage:
    //   idx:    0     1      2     3     4     5
    //   shader: rgen  isect  ahit  chit  miss  call
    //   group:  0     1      1     1     2     3
    // An empty name means the stage is unused. The "%s" placeholder is presumably filled with the
    // ray query test name by registerShaderModule (TODO confirm).
    const std::map<ShaderSourceType, std::vector<std::string>> shaderNames = {
        {SST_RAY_GENERATION_SHADER, {"rgen_%s", "", "", "", "", ""}},
        {SST_INTERSECTION_SHADER, {"rgen", "isect_%s", "", "chit_isect", "miss", ""}},
        {SST_ANY_HIT_SHADER, {"rgen", "isect", "ahit_%s", "", "miss", ""}},
        {SST_CLOSEST_HIT_SHADER, {"rgen", "isect", "", "chit_%s", "miss", ""}},
        {SST_MISS_SHADER, {"rgen", "isect", "", "chit", "miss_%s", ""}},
        {SST_CALLABLE_SHADER, {"rgen_call", "", "", "chit", "miss", "call_%s"}},
    };

    // Ray query test names indexed as [bottomType][shaderTestType]:
    // position 0 is the "gen" variant, position 1 the "skip" variant.
    std::vector<std::vector<std::string>> rayQueryTestName(2);
    rayQueryTestName[BTT_TRIANGLES].push_back("rq_gen_triangle");
    rayQueryTestName[BTT_AABBS].push_back("rq_gen_aabb");
    rayQueryTestName[BTT_TRIANGLES].push_back("rq_skip_triangle");
    rayQueryTestName[BTT_AABBS].push_back("rq_skip_aabb");

    auto shaderNameIt = shaderNames.find(testParams.shaderSourceType);
    if (shaderNameIt == end(shaderNames))
        TCU_THROW(InternalError, "Wrong shader source type");

    // Register each stage. The last argument appears to be the shader group index (it matches the
    // firstGroup values used for the SBTs below). Each returned flag is presumably true only when a
    // shader was actually registered, i.e. the name was non-empty (TODO confirm). The flags decide which
    // SBTs are created.
    bool rgenX, isectX, ahitX, chitX, missX, callX;
    rgenX = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_RAYGEN_BIT_KHR,
                                 shaderNameIt->second[0],
                                 rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 0);
    // The intersection stage is registered only for SST_INTERSECTION_SHADER. fillCommandBuffer builds
    // AABB geometry only in that case and triangles otherwise, so the "isect" entries above go unused.
    if (testParams.shaderSourceType == SST_INTERSECTION_SHADER)
        isectX = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
                                      shaderNameIt->second[1],
                                      rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 1);
    else
        isectX = false;
    ahitX     = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
                                     shaderNameIt->second[2],
                                     rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 1);
    chitX     = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
                                     shaderNameIt->second[3],
                                     rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 1);
    missX     = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_MISS_BIT_KHR,
                                     shaderNameIt->second[4],
                                     rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 2);
    callX     = registerShaderModule(vkd, device, context, *rayTracingPipeline, VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                                     shaderNameIt->second[5],
                                     rayQueryTestName[testParams.bottomType][testParams.shaderTestType], 3);
    // The hit group (group 1) exists if any of its three stages was registered.
    bool hitX = isectX || ahitX || chitX;

    rtPipeline = rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);

    uint32_t shaderGroupHandleSize    = getShaderGroupHandleSize(vki, physicalDevice);
    uint32_t shaderGroupBaseAlignment = getShaderGroupBaseAlignment(vki, physicalDevice);

    // One single-handle SBT per present group: raygen = 0, hit = 1, miss = 2, callable = 3.
    if (rgenX)
        raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
            vkd, device, *rtPipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
    if (hitX)
        hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
            vkd, device, *rtPipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
    if (missX)
        missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
            vkd, device, *rtPipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
    if (callX)
        callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
            vkd, device, *rtPipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
}
1014 
// Records the ray tracing workload into commandBuffer:
//  1. builds a bottom-level acceleration structure plus a one-instance top-level structure
//     (the build commands are recorded into the same command buffer);
//  2. writes the descriptor set: binding 0 = result image, binding 1 = the TLAS built here,
//     binding 2 = the acceleration structure supplied by the caller for the ray query;
//  3. binds the set and pipeline, then traces width x height x 1 rays.
void RayTracingConfiguration::fillCommandBuffer(
    Context &context, TestParams &testParams, VkCommandBuffer commandBuffer,
    const VkWriteDescriptorSetAccelerationStructureKHR &rayQueryAccelerationStructureWriteDescriptorSet,
    const VkDescriptorImageInfo &resultImageInfo)
{
    const InstanceInterface &vki          = context.getInstanceInterface();
    const DeviceInterface &vkd            = context.getDeviceInterface();
    const VkDevice device                 = context.getDevice();
    const VkPhysicalDevice physicalDevice = context.getPhysicalDevice();
    Allocator &allocator                  = context.getDefaultAllocator();

    {
        de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
            makeBottomLevelAccelerationStructure();
        bottomLevelAccelerationStructure->setGeometryCount(1);

        // The geometry spans x in [0, width] and y in [0, 0.5 * height].
        de::SharedPtr<RaytracedGeometryBase> geometry;
        if (testParams.shaderSourceType != SST_INTERSECTION_SHADER)
        {
            // Two non-indexed triangles (v0,v1,v2) and (v2,v1,v3) at z = 0.
            tcu::Vec3 v0(0.0f, 0.5f * float(testParams.height), 0.0f);
            tcu::Vec3 v1(0.0f, 0.0f, 0.0f);
            tcu::Vec3 v2(float(testParams.width), 0.5f * float(testParams.height), 0.0f);
            tcu::Vec3 v3(float(testParams.width), 0.0f, 0.0f);

            geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT,
                                             VK_INDEX_TYPE_NONE_KHR);
            geometry->addVertex(v0);
            geometry->addVertex(v1);
            geometry->addVertex(v2);
            geometry->addVertex(v2);
            geometry->addVertex(v1);
            geometry->addVertex(v3);
        }
        else // testParams.shaderSourceType == SST_INTERSECTION_SHADER
        {
            // One AABB given by its min (v0) and max (v1) corners, with z in [-0.1, 0.1].
            tcu::Vec3 v0(0.0f, 0.0f, -0.1f);
            tcu::Vec3 v1(float(testParams.width), 0.5f * float(testParams.height), 0.1f);

            geometry =
                makeRaytracedGeometry(VK_GEOMETRY_TYPE_AABBS_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
            geometry->addVertex(v0);
            geometry->addVertex(v1);
        }
        bottomLevelAccelerationStructure->addGeometry(geometry);
        bottomLevelAccelerationStructures.push_back(
            de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));

        // NOTE(review): bottomLevelAccelerationStructures is a member that only grows. If this function
        // runs more than once, earlier entries are built again here while the TLAS below still uses element [0].
        for (auto &blas : bottomLevelAccelerationStructures)
            blas->createAndBuild(vkd, device, commandBuffer, allocator);
    }

    topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
    topLevelAccelerationStructure->setInstanceCount(1);
    topLevelAccelerationStructure->addInstance(bottomLevelAccelerationStructures[0]);
    topLevelAccelerationStructure->createAndBuild(vkd, device, commandBuffer, allocator);

    const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = topLevelAccelerationStructure.get();
    VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
        VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
        DE_NULL,                                                           //  const void* pNext;
        1u,                                                                //  uint32_t accelerationStructureCount;
        topLevelAccelerationStructurePtr->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
    };

    // Binding 1 gets the TLAS built above; binding 2 gets the caller-provided ray query structure.
    DescriptorSetUpdateBuilder()
        .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
                     VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &resultImageInfo)
        .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
                     VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
        .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(2u),
                     VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &rayQueryAccelerationStructureWriteDescriptorSet)
        .update(vkd, device);

    vkd.cmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1,
                              &descriptorSet.get(), 0, DE_NULL);

    vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *rtPipeline);

    // Each SBT holds exactly one handle, so stride and size both equal the handle size.
    // Groups without an SBT get an empty (null address, zero size) region.
    uint32_t shaderGroupHandleSize = getShaderGroupHandleSize(vki, physicalDevice);
    VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion =
        raygenShaderBindingTable.get() != DE_NULL ?
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize) :
            makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion =
        hitShaderBindingTable.get() != DE_NULL ?
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize) :
            makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion =
        missShaderBindingTable.get() != DE_NULL ?
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize) :
            makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion =
        callableShaderBindingTable.get() != DE_NULL ?
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize) :
            makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);

    // Region order follows vkCmdTraceRaysKHR: raygen, miss, hit, callable.
    cmdTraceRays(vkd, commandBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
                 &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, testParams.width, testParams.height,
                 1);
}
1119 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)1120 bool RayTracingConfiguration::verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)
1121 {
1122     // create result image
1123     tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
1124     tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 2,
1125                                              resultBuffer->getAllocation().getHostPtr());
1126 
1127     // create reference image
1128     std::vector<uint32_t> reference(testParams.width * testParams.height * 2);
1129     tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 2, reference.data());
1130 
1131     tcu::UVec4 rqValue0, rqValue1;
1132     switch (testParams.shaderTestType)
1133     {
1134     case STT_GENERATE_INTERSECTION:
1135         switch (testParams.bottomType)
1136         {
1137         case BTT_TRIANGLES:
1138             rqValue0 = tcu::UVec4(1, 0, 0, 0);
1139             rqValue1 = tcu::UVec4(1, 0, 0, 0);
1140             break;
1141         case BTT_AABBS:
1142             rqValue0 = tcu::UVec4(2, 0, 0, 0);
1143             rqValue1 = tcu::UVec4(1, 0, 0, 0);
1144             break;
1145         default:
1146             TCU_THROW(InternalError, "Wrong bottom test type");
1147         }
1148         break;
1149     case STT_SKIP_INTERSECTION:
1150         switch (testParams.bottomType)
1151         {
1152         case BTT_TRIANGLES:
1153             rqValue0 = tcu::UVec4(0, 0, 0, 0);
1154             rqValue1 = tcu::UVec4(1, 0, 0, 0);
1155             break;
1156         case BTT_AABBS:
1157             rqValue0 = tcu::UVec4(0, 0, 0, 0);
1158             rqValue1 = tcu::UVec4(1, 0, 0, 0);
1159             break;
1160         default:
1161             TCU_THROW(InternalError, "Wrong bottom test type");
1162         }
1163         break;
1164     default:
1165         TCU_THROW(InternalError, "Wrong shader test type");
1166     }
1167 
1168     std::array<tcu::UVec4, 2> missMissValue, missHitValue, hitMissValue, hitHitValue;
1169     switch (testParams.shaderSourceType)
1170     {
1171     case SST_RAY_GENERATION_SHADER:
1172         missMissValue = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1173         missHitValue  = {{rqValue0, rqValue1}};
1174         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1175         hitHitValue   = {{rqValue0, rqValue1}};
1176         break;
1177     case SST_INTERSECTION_SHADER:
1178         missMissValue = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1179         missHitValue  = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1180         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1181         hitHitValue   = {{rqValue0, rqValue1}};
1182         break;
1183     case SST_ANY_HIT_SHADER:
1184         missMissValue = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1185         missHitValue  = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1186         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1187         hitHitValue   = {{rqValue0, rqValue1}};
1188         break;
1189     case SST_CLOSEST_HIT_SHADER:
1190         missMissValue = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1191         missHitValue  = {{tcu::UVec4(4, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1192         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1193         hitHitValue   = {{rqValue0, rqValue1}};
1194         break;
1195     case SST_MISS_SHADER:
1196         missMissValue = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1197         missHitValue  = {{rqValue0, rqValue1}};
1198         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(3, 0, 0, 0)}};
1199         hitHitValue   = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(3, 0, 0, 0)}};
1200         break;
1201     case SST_CALLABLE_SHADER:
1202         missMissValue = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1203         missHitValue  = {{rqValue0, rqValue1}};
1204         hitMissValue  = {{tcu::UVec4(0, 0, 0, 0), tcu::UVec4(0, 0, 0, 0)}};
1205         hitHitValue   = {{rqValue0, rqValue1}};
1206         break;
1207     default:
1208         TCU_THROW(InternalError, "Wrong shader source type");
1209     }
1210 
1211     for (uint32_t y = 0; y < testParams.height / 2; ++y)
1212         for (uint32_t x = 0; x < testParams.width; ++x)
1213         {
1214             referenceAccess.setPixel(hitMissValue[0], x, y, 0);
1215             referenceAccess.setPixel(hitMissValue[1], x, y, 1);
1216         }
1217     for (uint32_t y = testParams.height / 2; y < testParams.height; ++y)
1218         for (uint32_t x = 0; x < testParams.width; ++x)
1219         {
1220             referenceAccess.setPixel(missMissValue[0], x, y, 0);
1221             referenceAccess.setPixel(missMissValue[1], x, y, 1);
1222         }
1223 
1224     for (uint32_t y = 1; y < testParams.height / 2; ++y)
1225         for (uint32_t x = 1; x < testParams.width - 1; ++x)
1226         {
1227             referenceAccess.setPixel(hitHitValue[0], x, y, 0);
1228             referenceAccess.setPixel(hitHitValue[1], x, y, 1);
1229         }
1230 
1231     for (uint32_t y = testParams.height / 2; y < testParams.height - 1; ++y)
1232         for (uint32_t x = 1; x < testParams.width - 1; ++x)
1233         {
1234             referenceAccess.setPixel(missHitValue[0], x, y, 0);
1235             referenceAccess.setPixel(missHitValue[1], x, y, 1);
1236         }
1237 
1238     // compare result and reference
1239     return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess,
1240                                     resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
1241 }
1242 
getResultImageFormat()1243 VkFormat RayTracingConfiguration::getResultImageFormat()
1244 {
1245     return VK_FORMAT_R32_UINT;
1246 }
1247 
getResultImageFormatSize()1248 size_t RayTracingConfiguration::getResultImageFormatSize()
1249 {
1250     return sizeof(uint32_t);
1251 }
1252 
getClearValue()1253 VkClearValue RayTracingConfiguration::getClearValue()
1254 {
1255     return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
1256 }
1257 
1258 class RayQueryTraversalControlTestCase : public TestCase
1259 {
1260 public:
1261     RayQueryTraversalControlTestCase(tcu::TestContext &context, const char *name, const TestParams data);
1262     ~RayQueryTraversalControlTestCase(void);
1263 
1264     virtual void checkSupport(Context &context) const;
1265     virtual void initPrograms(SourceCollections &programCollection) const;
1266     virtual TestInstance *createInstance(Context &context) const;
1267 
1268 private:
1269     TestParams m_data;
1270 };
1271 
1272 class TraversalControlTestInstance : public TestInstance
1273 {
1274 public:
1275     TraversalControlTestInstance(Context &context, const TestParams &data);
1276     ~TraversalControlTestInstance(void);
1277     tcu::TestStatus iterate(void);
1278 
1279 private:
1280     TestParams m_data;
1281 };
1282 
RayQueryTraversalControlTestCase(tcu::TestContext & context,const char * name,const TestParams data)1283 RayQueryTraversalControlTestCase::RayQueryTraversalControlTestCase(tcu::TestContext &context, const char *name,
1284                                                                    const TestParams data)
1285     : vkt::TestCase(context, name)
1286     , m_data(data)
1287 {
1288 }
1289 
~RayQueryTraversalControlTestCase(void)1290 RayQueryTraversalControlTestCase::~RayQueryTraversalControlTestCase(void)
1291 {
1292 }
1293 
checkSupport(Context & context) const1294 void RayQueryTraversalControlTestCase::checkSupport(Context &context) const
1295 {
1296     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1297     context.requireDeviceFunctionality("VK_KHR_ray_query");
1298 
1299     const VkPhysicalDeviceRayQueryFeaturesKHR &rayQueryFeaturesKHR = context.getRayQueryFeatures();
1300     if (rayQueryFeaturesKHR.rayQuery == false)
1301         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayQueryFeaturesKHR.rayQuery");
1302 
1303     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
1304         context.getAccelerationStructureFeatures();
1305     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
1306         TCU_THROW(TestError,
1307                   "VK_KHR_ray_query requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1308 
1309     const VkPhysicalDeviceFeatures2 &features2 = context.getDeviceFeatures2();
1310 
1311     if ((m_data.shaderSourceType == SST_TESSELATION_CONTROL_SHADER ||
1312          m_data.shaderSourceType == SST_TESSELATION_EVALUATION_SHADER) &&
1313         features2.features.tessellationShader == false)
1314         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceFeatures2.tessellationShader");
1315 
1316     if (m_data.shaderSourceType == SST_GEOMETRY_SHADER && features2.features.geometryShader == false)
1317         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceFeatures2.geometryShader");
1318 
1319     switch (m_data.shaderSourceType)
1320     {
1321     case SST_VERTEX_SHADER:
1322     case SST_TESSELATION_CONTROL_SHADER:
1323     case SST_TESSELATION_EVALUATION_SHADER:
1324     case SST_GEOMETRY_SHADER:
1325         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
1326         break;
1327     default:
1328         break;
1329     }
1330 
1331     if (m_data.shaderSourceType == SST_RAY_GENERATION_SHADER || m_data.shaderSourceType == SST_INTERSECTION_SHADER ||
1332         m_data.shaderSourceType == SST_ANY_HIT_SHADER || m_data.shaderSourceType == SST_CLOSEST_HIT_SHADER ||
1333         m_data.shaderSourceType == SST_MISS_SHADER || m_data.shaderSourceType == SST_CALLABLE_SHADER)
1334     {
1335         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1336 
1337         const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
1338             context.getRayTracingPipelineFeatures();
1339 
1340         if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
1341             TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1342     }
1343 }
1344 
initPrograms(SourceCollections & programCollection) const1345 void RayQueryTraversalControlTestCase::initPrograms(SourceCollections &programCollection) const
1346 {
1347     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1348 
1349     // create parts of programs responsible for test execution
1350     std::vector<std::vector<std::string>> rayQueryTest(2);
1351     std::vector<std::vector<std::string>> rayQueryTestName(2);
1352     {
1353         // STT_GENERATE_INTERSECTION for triangles
1354         std::stringstream css;
1355         css << "  float tmin     = 0.0;\n"
1356                "  float tmax     = 1.0;\n"
1357                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1358                "  rayQueryEXT rq;\n"
1359                "  rayQueryInitializeEXT(rq, rqTopLevelAS, 0, 0xFF, origin, tmin, direct, tmax);\n"
1360                "  if(rayQueryProceedEXT(rq))\n"
1361                "  {\n"
1362                "    if (rayQueryGetIntersectionTypeEXT(rq, false)==gl_RayQueryCandidateIntersectionTriangleEXT)\n"
1363                "    {\n"
1364                "      hitValue.y=1;\n"
1365                "      rayQueryConfirmIntersectionEXT(rq);\n"
1366                "      rayQueryProceedEXT(rq);\n"
1367                "      hitValue.x = rayQueryGetIntersectionTypeEXT(rq, true);\n"
1368                "    }\n"
1369                "  }\n";
1370         rayQueryTest[BTT_TRIANGLES].push_back(css.str());
1371         rayQueryTestName[BTT_TRIANGLES].push_back("rq_gen_triangle");
1372     }
1373     {
1374         // STT_GENERATE_INTERSECTION for AABBs
1375         std::stringstream css;
1376         css << "  float tmin     = 0.0;\n"
1377                "  float tmax     = 1.0;\n"
1378                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1379                "  rayQueryEXT rq;\n"
1380                "  rayQueryInitializeEXT(rq, rqTopLevelAS, 0, 0xFF, origin, tmin, direct, tmax);\n"
1381                "  if(rayQueryProceedEXT(rq))\n"
1382                "  {\n"
1383                "    if (rayQueryGetIntersectionTypeEXT(rq, false)==gl_RayQueryCandidateIntersectionAABBEXT)\n"
1384                "    {\n"
1385                "      hitValue.y=1;\n"
1386                "      rayQueryGenerateIntersectionEXT(rq, 0.5);\n"
1387                "      rayQueryProceedEXT(rq);\n"
1388                "      hitValue.x = rayQueryGetIntersectionTypeEXT(rq, true);\n"
1389                "    }\n"
1390                "  }\n";
1391         rayQueryTest[BTT_AABBS].push_back(css.str());
1392         rayQueryTestName[BTT_AABBS].push_back("rq_gen_aabb");
1393     }
1394     {
1395         // STT_SKIP_INTERSECTION for triangles
1396         std::stringstream css;
1397         css << "  float tmin     = 0.0;\n"
1398                "  float tmax     = 1.0;\n"
1399                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1400                "  rayQueryEXT rq;\n"
1401                "  rayQueryInitializeEXT(rq, rqTopLevelAS, 0, 0xFF, origin, tmin, direct, tmax);\n"
1402                "  if(rayQueryProceedEXT(rq))\n"
1403                "  {\n"
1404                "    if (rayQueryGetIntersectionTypeEXT(rq, false)==gl_RayQueryCandidateIntersectionTriangleEXT)\n"
1405                "    {\n"
1406                "      hitValue.y=1;\n"
1407                "      rayQueryProceedEXT(rq);\n"
1408                "      hitValue.x = rayQueryGetIntersectionTypeEXT(rq, true);\n"
1409                "    }\n"
1410                "  }\n";
1411         rayQueryTest[BTT_TRIANGLES].push_back(css.str());
1412         rayQueryTestName[BTT_TRIANGLES].push_back("rq_skip_triangle");
1413     }
1414     {
1415         // STT_SKIP_INTERSECTION for AABBs
1416         std::stringstream css;
1417         css << "  float tmin     = 0.0;\n"
1418                "  float tmax     = 1.0;\n"
1419                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1420                "  rayQueryEXT rq;\n"
1421                "  rayQueryInitializeEXT(rq, rqTopLevelAS, 0, 0xFF, origin, tmin, direct, tmax);\n"
1422                "  if(rayQueryProceedEXT(rq))\n"
1423                "  {\n"
1424                "    if (rayQueryGetIntersectionTypeEXT(rq, false)==gl_RayQueryCandidateIntersectionAABBEXT)\n"
1425                "    {\n"
1426                "      hitValue.y=1;\n"
1427                "      rayQueryProceedEXT(rq);\n"
1428                "      hitValue.x = rayQueryGetIntersectionTypeEXT(rq, true);\n"
1429                "    }\n"
1430                "  }\n";
1431         rayQueryTest[BTT_AABBS].push_back(css.str());
1432         rayQueryTestName[BTT_AABBS].push_back("rq_skip_aabb");
1433     }
1434 
1435     // create all programs
1436     if (m_data.shaderSourcePipeline == SSP_GRAPHICS_PIPELINE)
1437     {
1438         {
1439             std::stringstream css;
1440             css << "#version 460 core\n"
1441                    "layout (location = 0) in vec3 position;\n"
1442                    "out gl_PerVertex\n"
1443                    "{\n"
1444                    "  vec4 gl_Position;\n"
1445                    "};\n"
1446                    "void main()\n"
1447                    "{\n"
1448                    "  gl_Position = vec4(position, 1.0);\n"
1449                    "}\n";
1450             programCollection.glslSources.add("vert") << glu::VertexSource(css.str()) << buildOptions;
1451         }
1452 
1453         {
1454             std::stringstream css;
1455             css << "#version 460 core\n"
1456                    "#extension GL_EXT_ray_query : require\n"
1457                    "layout (location = 0) in vec3 position;\n"
1458                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1459                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1460                    "void main()\n"
1461                    "{\n"
1462                    "  vec3  origin   = vec3(float(position.x) + 0.5, float(position.y) + 0.5, 0.5);\n"
1463                    "  uvec4 hitValue = uvec4(0,0,0,0);\n"
1464                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1465                 << "  imageStore(result, ivec3(gl_VertexIndex, 0, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1466                    "  imageStore(result, ivec3(gl_VertexIndex, 0, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1467                    "  gl_Position = vec4(position,1);\n"
1468                    "}\n";
1469             std::stringstream cssName;
1470             cssName << "vert_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1471 
1472             programCollection.glslSources.add(cssName.str()) << glu::VertexSource(css.str()) << buildOptions;
1473         }
1474 
1475         {
1476             std::stringstream css;
1477             css << "#version 460 core\n"
1478                    "#extension GL_EXT_tessellation_shader : require\n"
1479                    "in gl_PerVertex {\n"
1480                    "  vec4  gl_Position;\n"
1481                    "} gl_in[];\n"
1482                    "layout(vertices = 3) out;\n"
1483                    "void main (void)\n"
1484                    "{\n"
1485                    "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
1486                    "  gl_TessLevelInner[0] = 1;\n"
1487                    "  gl_TessLevelOuter[0] = 1;\n"
1488                    "  gl_TessLevelOuter[1] = 1;\n"
1489                    "  gl_TessLevelOuter[2] = 1;\n"
1490                    "}\n";
1491             programCollection.glslSources.add("tesc") << glu::TessellationControlSource(css.str()) << buildOptions;
1492         }
1493 
1494         {
1495             std::stringstream css;
1496             css << "#version 460 core\n"
1497                    "#extension GL_EXT_tessellation_shader : require\n"
1498                    "#extension GL_EXT_ray_query : require\n"
1499                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1500                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1501                    "in gl_PerVertex {\n"
1502                    "  vec4  gl_Position;\n"
1503                    "} gl_in[];\n"
1504                    "layout(vertices = 3) out;\n"
1505                    "void main (void)\n"
1506                    "{\n"
1507                    "  vec3  origin   = vec3(gl_in[gl_InvocationID].gl_Position.x + 0.5, "
1508                    "gl_in[gl_InvocationID].gl_Position.y + 0.5, 0.5);\n"
1509                    "  uvec4 hitValue = uvec4(0,0,0,0);\n"
1510                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1511                 << "  imageStore(result, ivec3(gl_PrimitiveID, gl_InvocationID, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1512                    "  imageStore(result, ivec3(gl_PrimitiveID, gl_InvocationID, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1513                    "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
1514                    "  gl_TessLevelInner[0] = 1;\n"
1515                    "  gl_TessLevelOuter[0] = 1;\n"
1516                    "  gl_TessLevelOuter[1] = 1;\n"
1517                    "  gl_TessLevelOuter[2] = 1;\n"
1518                    "}\n";
1519             std::stringstream cssName;
1520             cssName << "tesc_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1521 
1522             programCollection.glslSources.add(cssName.str())
1523                 << glu::TessellationControlSource(css.str()) << buildOptions;
1524         }
1525 
1526         {
1527             std::stringstream css;
1528             css << "#version 460 core\n"
1529                    "#extension GL_EXT_tessellation_shader : require\n"
1530                    "#extension GL_EXT_ray_query : require\n"
1531                    "layout(triangles, equal_spacing, ccw) in;\n"
1532                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1533                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1534                    "void main (void)\n"
1535                    "{\n"
1536                    "  for (int i = 0; i < 3; ++i)\n"
1537                    "  {\n"
1538                    "    vec3  origin   = vec3(gl_in[i].gl_Position.x + 0.5, gl_in[i].gl_Position.y + 0.5, 0.5);\n"
1539                    "    uvec4 hitValue = uvec4(0,0,0,0);\n"
1540                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1541                 << "    imageStore(result, ivec3(gl_PrimitiveID, i, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1542                    "    imageStore(result, ivec3(gl_PrimitiveID, i, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1543                    "  }\n"
1544                    "  gl_Position = gl_in[0].gl_Position;\n"
1545                    "}\n";
1546             std::stringstream cssName;
1547             cssName << "tese_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1548 
1549             programCollection.glslSources.add(cssName.str())
1550                 << glu::TessellationEvaluationSource(css.str()) << buildOptions;
1551         }
1552 
1553         {
1554             std::stringstream css;
1555             css << "#version 460 core\n"
1556                    "#extension GL_EXT_tessellation_shader : require\n"
1557                    "layout(triangles, equal_spacing, ccw) in;\n"
1558                    "void main (void)\n"
1559                    "{\n"
1560                    "  gl_Position = gl_in[0].gl_Position;\n"
1561                    "}\n";
1562 
1563             programCollection.glslSources.add("tese") << glu::TessellationEvaluationSource(css.str()) << buildOptions;
1564         }
1565 
1566         {
1567             std::stringstream css;
1568             css << "#version 460 core\n"
1569                    "#extension GL_EXT_ray_query : require\n"
1570                    "layout(triangles) in;\n"
1571                    "layout (triangle_strip, max_vertices = 4) out;\n"
1572                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1573                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1574                    "\n"
1575                    "in gl_PerVertex {\n"
1576                    "  vec4  gl_Position;\n"
1577                    "} gl_in[];\n"
1578                    "out gl_PerVertex {\n"
1579                    "  vec4 gl_Position;\n"
1580                    "};\n"
1581                    "void main (void)\n"
1582                    "{\n"
1583                    "  for (int i = 0; i < gl_in.length(); ++i)\n"
1584                    "  {\n"
1585                    "    vec3  origin   = vec3(gl_in[i].gl_Position.x + 0.5, gl_in[i].gl_Position.y + 0.5, 0.5);\n"
1586                    "    uvec4 hitValue = uvec4(0,0,0,0);\n"
1587                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1588                 << "    imageStore(result, ivec3(gl_PrimitiveIDIn, i, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1589                    "    imageStore(result, ivec3(gl_PrimitiveIDIn, i, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1590                    "    gl_Position      = gl_in[i].gl_Position;\n"
1591                    "    EmitVertex();\n"
1592                    "  }\n"
1593                    "  EndPrimitive();\n"
1594                    "}\n";
1595             std::stringstream cssName;
1596             cssName << "geom_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1597 
1598             programCollection.glslSources.add(cssName.str()) << glu::GeometrySource(css.str()) << buildOptions;
1599         }
1600 
1601         {
1602             std::stringstream css;
1603             css << "#version 460 core\n"
1604                    "#extension GL_EXT_ray_query : require\n"
1605                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1606                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1607                    "void main()\n"
1608                    "{\n"
1609                    "  vec3  origin   = vec3(gl_FragCoord.x, gl_FragCoord.y, 0.5);\n"
1610                    "  uvec4 hitValue = uvec4(0,0,0,0);\n"
1611                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1612                 << "  imageStore(result, ivec3(gl_FragCoord.xy-vec2(0.5,0.5), 0), uvec4(hitValue.x, 0, 0, 0));\n"
1613                    "  imageStore(result, ivec3(gl_FragCoord.xy-vec2(0.5,0.5), 1), uvec4(hitValue.y, 0, 0, 0));\n"
1614                    "}\n";
1615             std::stringstream cssName;
1616             cssName << "frag_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1617 
1618             programCollection.glslSources.add(cssName.str()) << glu::FragmentSource(css.str()) << buildOptions;
1619         }
1620     }
1621     else if (m_data.shaderSourcePipeline == SSP_COMPUTE_PIPELINE)
1622     {
1623         {
1624             std::stringstream css;
1625             css << "#version 460 core\n"
1626                    "#extension GL_EXT_ray_query : require\n"
1627                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1628                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT rqTopLevelAS;\n"
1629                    "void main()\n"
1630                    "{\n"
1631                    "  vec3  origin   = vec3(float(gl_GlobalInvocationID.x) + 0.5, float(gl_GlobalInvocationID.y) + "
1632                    "0.5, 0.5);\n"
1633                    "  uvec4 hitValue = uvec4(0,0,0,0);\n"
1634                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1635                 << "  imageStore(result, ivec3(gl_GlobalInvocationID.xy, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1636                    "  imageStore(result, ivec3(gl_GlobalInvocationID.xy, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1637                    "}\n";
1638             std::stringstream cssName;
1639             cssName << "comp_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1640 
1641             programCollection.glslSources.add(cssName.str()) << glu::ComputeSource(css.str()) << buildOptions;
1642         }
1643     }
1644     else if (m_data.shaderSourcePipeline == SSP_RAY_TRACING_PIPELINE)
1645     {
1646         {
1647             std::stringstream css;
1648             css << "#version 460 core\n"
1649                    "#extension GL_EXT_ray_tracing : require\n"
1650                    "layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
1651                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1652                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
1653                    "void main()\n"
1654                    "{\n"
1655                    "  float tmin     = 0.0;\n"
1656                    "  float tmax     = 1.0;\n"
1657                    "  vec3  origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5, float(gl_LaunchIDEXT.y) + 0.5, 0.5);\n"
1658                    "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1659                    "  hitValue       = uvec4(0,0,0,0);\n"
1660                    "  traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
1661                    "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1662                    "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1663                    "}\n";
1664             programCollection.glslSources.add("rgen")
1665                 << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
1666         }
1667 
1668         {
1669             std::stringstream css;
1670             css << "#version 460 core\n"
1671                    "#extension GL_EXT_ray_tracing : require\n"
1672                    "#extension GL_EXT_ray_query : require\n"
1673                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1674                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
1675                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1676                    "void main()\n"
1677                    "{\n"
1678                    "  vec3  origin    = vec3(float(gl_LaunchIDEXT.x) + 0.5, float(gl_LaunchIDEXT.y) + 0.5, 0.5);\n"
1679                    "  uvec4  hitValue = uvec4(0,0,0,0);\n"
1680                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1681                 << "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 0), uvec4(hitValue.x, 0, 0, 0));\n"
1682                    "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 1), uvec4(hitValue.y, 0, 0, 0));\n"
1683                    "}\n";
1684             std::stringstream cssName;
1685             cssName << "rgen_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1686 
1687             programCollection.glslSources.add(cssName.str())
1688                 << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
1689         }
1690 
1691         {
1692             std::stringstream css;
1693             css << "#version 460 core\n"
1694                    "#extension GL_EXT_ray_tracing : require\n"
1695                    "struct CallValue\n{\n"
1696                    "  vec3  origin;\n"
1697                    "  uvec4 hitValue;\n"
1698                    "};\n"
1699                    "layout(location = 0) callableDataEXT CallValue param;\n"
1700                    "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
1701                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
1702                    "void main()\n"
1703                    "{\n"
1704                    "  param.origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5, float(gl_LaunchIDEXT.y) + 0.5, 0.5);\n"
1705                    "  param.hitValue = uvec4(0, 0, 0, 0);\n"
1706                    "  executeCallableEXT(0, 0);\n"
1707                    "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 0), uvec4(param.hitValue.x, 0, 0, 0));\n"
1708                    "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 1), uvec4(param.hitValue.y, 0, 0, 0));\n"
1709                    "}\n";
1710             programCollection.glslSources.add("rgen_call")
1711                 << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
1712         }
1713 
1714         {
1715             std::stringstream css;
1716             css << "#version 460 core\n"
1717                    "#extension GL_EXT_ray_tracing : require\n"
1718                    "hitAttributeEXT uvec4 hitValue;\n"
1719                    "void main()\n"
1720                    "{\n"
1721                    "  reportIntersectionEXT(0.5f, 0);\n"
1722                    "}\n";
1723 
1724             programCollection.glslSources.add("isect")
1725                 << glu::IntersectionSource(updateRayTracingGLSL(css.str())) << buildOptions;
1726         }
1727 
1728         {
1729             std::stringstream css;
1730             css << "#version 460 core\n"
1731                    "#extension GL_EXT_ray_tracing : require\n"
1732                    "#extension GL_EXT_ray_query : require\n"
1733                    "hitAttributeEXT uvec4 hitValue;\n"
1734                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1735                    "void main()\n"
1736                    "{\n"
1737                    "  vec3 origin = gl_WorldRayOriginEXT;\n"
1738                    "  hitValue    = uvec4(0,0,0,0);\n"
1739                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1740                 << "  reportIntersectionEXT(0.5f, 0);\n"
1741                    "}\n";
1742             std::stringstream cssName;
1743             cssName << "isect_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1744 
1745             programCollection.glslSources.add(cssName.str())
1746                 << glu::IntersectionSource(updateRayTracingGLSL(css.str())) << buildOptions;
1747         }
1748 
1749         {
1750             std::stringstream css;
1751             css << "#version 460 core\n"
1752                    "#extension GL_EXT_ray_tracing : require\n"
1753                    "#extension GL_EXT_ray_query : require\n"
1754                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1755                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1756                    "void main()\n"
1757                    "{\n"
1758                    "  vec3 origin = gl_WorldRayOriginEXT;\n"
1759                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType] << "}\n";
1760             std::stringstream cssName;
1761             cssName << "ahit_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1762 
1763             programCollection.glslSources.add(cssName.str())
1764                 << glu::AnyHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
1765         }
1766 
1767         {
1768             std::stringstream css;
1769             css << "#version 460 core\n"
1770                    "#extension GL_EXT_ray_tracing : require\n"
1771                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1772                    "void main()\n"
1773                    "{\n"
1774                    "  hitValue.y = 3;\n"
1775                    "}\n";
1776 
1777             programCollection.glslSources.add("chit")
1778                 << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
1779         }
1780 
1781         {
1782             std::stringstream css;
1783             css << "#version 460 core\n"
1784                    "#extension GL_EXT_ray_tracing : require\n"
1785                    "#extension GL_EXT_ray_query : require\n"
1786                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1787                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1788                    "void main()\n"
1789                    "{\n"
1790                    "  vec3 origin = gl_WorldRayOriginEXT;\n"
1791                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType] << "}\n";
1792             std::stringstream cssName;
1793             cssName << "chit_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1794 
1795             programCollection.glslSources.add(cssName.str())
1796                 << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
1797         }
1798 
1799         {
1800             std::stringstream css;
1801             css << "#version 460 core\n"
1802                    "#extension GL_EXT_ray_tracing : require\n"
1803                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1804                    "hitAttributeEXT uvec4 hitAttrib;\n"
1805                    "void main()\n"
1806                    "{\n"
1807                    "  hitValue = hitAttrib;\n"
1808                    "}\n";
1809 
1810             programCollection.glslSources.add("chit_isect")
1811                 << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
1812         }
1813 
1814         {
1815             std::stringstream css;
1816             css << "#version 460 core\n"
1817                    "#extension GL_EXT_ray_tracing : require\n"
1818                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1819                    "void main()\n"
1820                    "{\n"
1821                    "  hitValue.x = 4;\n"
1822                    "}\n";
1823 
1824             programCollection.glslSources.add("miss")
1825                 << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
1826         }
1827 
1828         {
1829             std::stringstream css;
1830             css << "#version 460 core\n"
1831                    "#extension GL_EXT_ray_tracing : require\n"
1832                    "#extension GL_EXT_ray_query : require\n"
1833                    "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
1834                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1835                    "void main()\n"
1836                    "{\n"
1837                    "  vec3 origin = gl_WorldRayOriginEXT;\n"
1838                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType] << "}\n";
1839             std::stringstream cssName;
1840             cssName << "miss_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1841 
1842             programCollection.glslSources.add(cssName.str())
1843                 << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
1844         }
1845 
1846         {
1847             std::stringstream css;
1848             css << "#version 460 core\n"
1849                    "#extension GL_EXT_ray_tracing : require\n"
1850                    "#extension GL_EXT_ray_query : require\n"
1851                    "struct CallValue\n{\n"
1852                    "  vec3  origin;\n"
1853                    "  uvec4 hitValue;\n"
1854                    "};\n"
1855                    "layout(location = 0) callableDataInEXT CallValue result;\n"
1856                    "layout(set = 0, binding = 2) uniform accelerationStructureEXT rqTopLevelAS;\n"
1857                    "void main()\n"
1858                    "{\n"
1859                    "  vec3 origin    = result.origin;\n"
1860                    "  uvec4 hitValue = uvec4(0,0,0,0);\n"
1861                 << rayQueryTest[m_data.bottomType][m_data.shaderTestType]
1862                 << "  result.hitValue = hitValue;\n"
1863                    "}\n";
1864             std::stringstream cssName;
1865             cssName << "call_" << rayQueryTestName[m_data.bottomType][m_data.shaderTestType];
1866 
1867             programCollection.glslSources.add(cssName.str())
1868                 << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
1869         }
1870     }
1871 }
1872 
createInstance(Context & context) const1873 TestInstance *RayQueryTraversalControlTestCase::createInstance(Context &context) const
1874 {
1875     return new TraversalControlTestInstance(context, m_data);
1876 }
1877 
TraversalControlTestInstance(Context & context,const TestParams & data)1878 TraversalControlTestInstance::TraversalControlTestInstance(Context &context, const TestParams &data)
1879     : vkt::TestInstance(context)
1880     , m_data(data)
1881 {
1882 }
1883 
~TraversalControlTestInstance(void)1884 TraversalControlTestInstance::~TraversalControlTestInstance(void)
1885 {
1886 }
1887 
iterate(void)1888 tcu::TestStatus TraversalControlTestInstance::iterate(void)
1889 {
1890     de::SharedPtr<TestConfiguration> testConfiguration;
1891 
1892     switch (m_data.shaderSourcePipeline)
1893     {
1894     case SSP_GRAPHICS_PIPELINE:
1895         testConfiguration = de::SharedPtr<TestConfiguration>(new GraphicsConfiguration());
1896         break;
1897     case SSP_COMPUTE_PIPELINE:
1898         testConfiguration = de::SharedPtr<TestConfiguration>(new ComputeConfiguration());
1899         break;
1900     case SSP_RAY_TRACING_PIPELINE:
1901         testConfiguration = de::SharedPtr<TestConfiguration>(new RayTracingConfiguration());
1902         break;
1903     default:
1904         TCU_THROW(InternalError, "Wrong shader source pipeline");
1905     }
1906 
1907     testConfiguration->initConfiguration(m_context, m_data);
1908 
1909     const DeviceInterface &vkd      = m_context.getDeviceInterface();
1910     const VkDevice device           = m_context.getDevice();
1911     const VkQueue queue             = m_context.getUniversalQueue();
1912     Allocator &allocator            = m_context.getDefaultAllocator();
1913     const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
1914 
1915     const VkFormat imageFormat              = testConfiguration->getResultImageFormat();
1916     const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_data.width, m_data.height, 2, imageFormat);
1917     const VkImageSubresourceRange imageSubresourceRange =
1918         makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
1919     const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(
1920         new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
1921     const Move<VkImageView> imageView =
1922         makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, imageFormat, imageSubresourceRange);
1923 
1924     const VkBufferCreateInfo resultBufferCreateInfo =
1925         makeBufferCreateInfo(m_data.width * m_data.height * 2 * testConfiguration->getResultImageFormatSize(),
1926                              VK_BUFFER_USAGE_TRANSFER_DST_BIT);
1927     const VkImageSubresourceLayers resultBufferImageSubresourceLayers =
1928         makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
1929     const VkBufferImageCopy resultBufferImageRegion =
1930         makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 2), resultBufferImageSubresourceLayers);
1931     de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(
1932         new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
1933 
1934     const VkDescriptorImageInfo resultImageInfo = makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
1935 
1936     const Move<VkCommandPool> cmdPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
1937     const Move<VkCommandBuffer> cmdBuffer =
1938         allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1939 
1940     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> rayQueryBottomLevelAccelerationStructures;
1941     de::MovePtr<TopLevelAccelerationStructure> rayQueryTopLevelAccelerationStructure;
1942 
1943     beginCommandBuffer(vkd, *cmdBuffer, 0u);
1944     {
1945         const VkImageMemoryBarrier preImageBarrier =
1946             makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
1947                                    VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, **image, imageSubresourceRange);
1948         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
1949                                       VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
1950 
1951         const VkClearValue clearValue = testConfiguration->getClearValue();
1952         vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1,
1953                                &imageSubresourceRange);
1954 
1955         const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(
1956             VK_ACCESS_TRANSFER_WRITE_BIT,
1957             VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
1958             VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, **image, imageSubresourceRange);
1959         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
1960                                       VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
1961 
1962         // build acceleration structures for ray query
1963         {
1964             de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
1965                 makeBottomLevelAccelerationStructure();
1966             bottomLevelAccelerationStructure->setGeometryCount(1);
1967 
1968             de::SharedPtr<RaytracedGeometryBase> geometry;
1969             if (m_data.bottomType == BTT_TRIANGLES)
1970             {
1971                 tcu::Vec3 v0(1.0f, float(m_data.height) - 1.0f, 0.0f);
1972                 tcu::Vec3 v1(1.0f, 1.0f, 0.0f);
1973                 tcu::Vec3 v2(float(m_data.width) - 1.0f, float(m_data.height) - 1.0f, 0.0f);
1974                 tcu::Vec3 v3(float(m_data.width) - 1.0f, 1.0f, 0.0f);
1975 
1976                 geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT,
1977                                                  VK_INDEX_TYPE_NONE_KHR);
1978                 geometry->addVertex(v0);
1979                 geometry->addVertex(v1);
1980                 geometry->addVertex(v2);
1981                 geometry->addVertex(v2);
1982                 geometry->addVertex(v1);
1983                 geometry->addVertex(v3);
1984             }
1985             else // testParams.bottomType != BTT_TRIANGLES
1986             {
1987                 tcu::Vec3 v0(1.0f, 1.0f, -0.1f);
1988                 tcu::Vec3 v1(float(m_data.width) - 1.0f, float(m_data.height) - 1.0f, 0.1f);
1989 
1990                 geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_AABBS_KHR, VK_FORMAT_R32G32B32_SFLOAT,
1991                                                  VK_INDEX_TYPE_NONE_KHR);
1992                 geometry->addVertex(v0);
1993                 geometry->addVertex(v1);
1994             }
1995             bottomLevelAccelerationStructure->addGeometry(geometry);
1996             rayQueryBottomLevelAccelerationStructures.push_back(
1997                 de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
1998         }
1999 
2000         for (auto &blas : rayQueryBottomLevelAccelerationStructures)
2001             blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
2002 
2003         rayQueryTopLevelAccelerationStructure = makeTopLevelAccelerationStructure();
2004         rayQueryTopLevelAccelerationStructure->setInstanceCount(1);
2005         rayQueryTopLevelAccelerationStructure->addInstance(rayQueryBottomLevelAccelerationStructures[0]);
2006         rayQueryTopLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);
2007 
2008         const TopLevelAccelerationStructure *rayQueryTopLevelAccelerationStructurePtr =
2009             rayQueryTopLevelAccelerationStructure.get();
2010         VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
2011             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
2012             DE_NULL,                                                           //  const void* pNext;
2013             1u,                                                                //  uint32_t accelerationStructureCount;
2014             rayQueryTopLevelAccelerationStructurePtr
2015                 ->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
2016         };
2017 
2018         testConfiguration->fillCommandBuffer(m_context, m_data, *cmdBuffer, accelerationStructureWriteDescriptorSet,
2019                                              resultImageInfo);
2020 
2021         const VkMemoryBarrier postTestMemoryBarrier =
2022             makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
2023         const VkMemoryBarrier postCopyMemoryBarrier =
2024             makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
2025         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT,
2026                                  &postTestMemoryBarrier);
2027 
2028         vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u,
2029                                  &resultBufferImageRegion);
2030 
2031         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
2032                                  &postCopyMemoryBarrier);
2033     }
2034     endCommandBuffer(vkd, *cmdBuffer);
2035 
2036     submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
2037 
2038     invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(),
2039                                 resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
2040 
2041     bool result = testConfiguration->verifyImage(resultBuffer.get(), m_context, m_data);
2042 
2043     rayQueryTopLevelAccelerationStructure.clear();
2044     rayQueryBottomLevelAccelerationStructures.clear();
2045     testConfiguration.clear();
2046 
2047     if (!result)
2048         return tcu::TestStatus::fail("Fail");
2049     return tcu::TestStatus::pass("Pass");
2050 }
2051 
2052 } // namespace
2053 
createTraversalControlTests(tcu::TestContext & testCtx)2054 tcu::TestCaseGroup *createTraversalControlTests(tcu::TestContext &testCtx)
2055 {
2056     // Tests verifying traversal control in RT hit shaders
2057     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "traversal_control"));
2058 
2059     struct ShaderSourceTypeData
2060     {
2061         ShaderSourceType shaderSourceType;
2062         ShaderSourcePipeline shaderSourcePipeline;
2063         const char *name;
2064     } shaderSourceTypes[] = {
2065         {SST_VERTEX_SHADER, SSP_GRAPHICS_PIPELINE, "vertex_shader"},
2066         {SST_TESSELATION_CONTROL_SHADER, SSP_GRAPHICS_PIPELINE, "tess_control_shader"},
2067         {SST_TESSELATION_EVALUATION_SHADER, SSP_GRAPHICS_PIPELINE, "tess_evaluation_shader"},
2068         {
2069             SST_GEOMETRY_SHADER,
2070             SSP_GRAPHICS_PIPELINE,
2071             "geometry_shader",
2072         },
2073         {
2074             SST_FRAGMENT_SHADER,
2075             SSP_GRAPHICS_PIPELINE,
2076             "fragment_shader",
2077         },
2078         {
2079             SST_COMPUTE_SHADER,
2080             SSP_COMPUTE_PIPELINE,
2081             "compute_shader",
2082         },
2083         {
2084             SST_RAY_GENERATION_SHADER,
2085             SSP_RAY_TRACING_PIPELINE,
2086             "rgen_shader",
2087         },
2088         {
2089             SST_INTERSECTION_SHADER,
2090             SSP_RAY_TRACING_PIPELINE,
2091             "isect_shader",
2092         },
2093         {
2094             SST_ANY_HIT_SHADER,
2095             SSP_RAY_TRACING_PIPELINE,
2096             "ahit_shader",
2097         },
2098         {
2099             SST_CLOSEST_HIT_SHADER,
2100             SSP_RAY_TRACING_PIPELINE,
2101             "chit_shader",
2102         },
2103         {
2104             SST_MISS_SHADER,
2105             SSP_RAY_TRACING_PIPELINE,
2106             "miss_shader",
2107         },
2108         {
2109             SST_CALLABLE_SHADER,
2110             SSP_RAY_TRACING_PIPELINE,
2111             "call_shader",
2112         },
2113     };
2114 
2115     struct ShaderTestTypeData
2116     {
2117         ShaderTestType shaderTestType;
2118         const char *name;
2119     } shaderTestTypes[] = {
2120         {STT_GENERATE_INTERSECTION, "generate_intersection"},
2121         {STT_SKIP_INTERSECTION, "skip_intersection"},
2122     };
2123 
2124     struct
2125     {
2126         BottomTestType testType;
2127         const char *name;
2128     } bottomTestTypes[] = {
2129         {BTT_TRIANGLES, "triangles"},
2130         {BTT_AABBS, "aabbs"},
2131     };
2132 
2133     for (size_t shaderSourceNdx = 0; shaderSourceNdx < DE_LENGTH_OF_ARRAY(shaderSourceTypes); ++shaderSourceNdx)
2134     {
2135         de::MovePtr<tcu::TestCaseGroup> sourceTypeGroup(
2136             new tcu::TestCaseGroup(group->getTestContext(), shaderSourceTypes[shaderSourceNdx].name));
2137 
2138         for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(shaderTestTypes); ++shaderTestNdx)
2139         {
2140             de::MovePtr<tcu::TestCaseGroup> testTypeGroup(
2141                 new tcu::TestCaseGroup(group->getTestContext(), shaderTestTypes[shaderTestNdx].name));
2142 
2143             for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(bottomTestTypes); ++testTypeNdx)
2144             {
2145                 TestParams testParams{TEST_WIDTH,
2146                                       TEST_HEIGHT,
2147                                       shaderSourceTypes[shaderSourceNdx].shaderSourceType,
2148                                       shaderSourceTypes[shaderSourceNdx].shaderSourcePipeline,
2149                                       shaderTestTypes[shaderTestNdx].shaderTestType,
2150                                       bottomTestTypes[testTypeNdx].testType};
2151                 testTypeGroup->addChild(new RayQueryTraversalControlTestCase(
2152                     group->getTestContext(), bottomTestTypes[testTypeNdx].name, testParams));
2153             }
2154             sourceTypeGroup->addChild(testTypeGroup.release());
2155         }
2156         group->addChild(sourceTypeGroup.release());
2157     }
2158 
2159     return group.release();
2160 }
2161 
2162 } // namespace RayQuery
2163 
2164 } // namespace vkt
2165