1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Complex Control Flow tests
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayTracingComplexControlFlowTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vkCmdUtil.hpp"
30 #include "vkObjUtil.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkBarrierUtil.hpp"
33 #include "vkBufferWithMemory.hpp"
34 #include "vkImageWithMemory.hpp"
35 #include "vkTypeUtil.hpp"
36 
37 #include "vkRayTracingUtil.hpp"
38 
39 #include "tcuTestLog.hpp"
40 
41 #include "deRandom.hpp"
42 
43 namespace vkt
44 {
45 namespace RayTracing
46 {
47 namespace
48 {
49 using namespace vk;
50 using namespace std;
51 
52 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
53                                               VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
54                                               VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
55 
56 #if defined(DE_DEBUG)
57 static const uint32_t PUSH_CONSTANTS_COUNT = 6;
58 #endif
59 static const uint32_t DEFAULT_CLEAR_VALUE = 999999;
60 
// Control-flow construct selected for a test case. Values are spelled out explicitly;
// they match the implicit numbering (0, 1, 2, ...).
enum TestType
{
    TEST_TYPE_IF                      = 0,
    TEST_TYPE_LOOP                    = 1,
    TEST_TYPE_SWITCH                  = 2,
    TEST_TYPE_LOOP_DOUBLE_CALL        = 3,
    TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE = 4,
    TEST_TYPE_NESTED_LOOP             = 5,
    TEST_TYPE_NESTED_LOOP_BEFORE      = 6,
    TEST_TYPE_NESTED_LOOP_AFTER       = 7,
    TEST_TYPE_FUNCTION_CALL           = 8,
    TEST_TYPE_NESTED_FUNCTION_CALL    = 9,
};
74 
// Ray tracing operation exercised inside the control-flow construct of a test case.
enum TestOp
{
    TEST_OP_EXECUTE_CALLABLE    = 0,
    TEST_OP_TRACE_RAY           = 1,
    TEST_OP_REPORT_INTERSECTION = 2,
};
81 
// Symbolic shader group indices (raygen, miss, hit) plus the number of such groups.
enum ShaderGroups
{
    FIRST_GROUP  = 0,
    RAYGEN_GROUP = 0, // same value as FIRST_GROUP
    MISS_GROUP   = 1,
    HIT_GROUP    = 2,
    GROUP_COUNT  = 3
};
90 
// Parameters of one generated test case. Field order is kept as-is because the struct
// is an aggregate and may be brace-initialized positionally elsewhere in the file.
struct CaseDef
{
    TestType testType;           // control-flow construct used by the test
    TestOp testOp;               // operation under test (execute callable, trace ray, report intersection)
    VkShaderStageFlagBits stage; // shader stage the operation runs in (compared against raygen/miss in this file)
    uint32_t width;              // width of the 3D output storage image
    uint32_t height;             // height of the 3D output storage image
};
99 
// Push-constant payload. runTest() passes a pointer to an instance of this struct
// straight to vkCmdPushConstants with the size of six uint32_t values, so the field
// order and types must not change; presumably they mirror the push-constant block
// declared in the test shaders.
struct PushConstants
{
    uint32_t a;      // test-type specific parameter (values chosen in getPushConstants())
    uint32_t b;      // test-type specific parameter
    uint32_t c;      // test-type specific parameter
    uint32_t d;      // test-type specific parameter
    uint32_t hitOfs; // set to 1 for every test type; presumably a hit-group SBT offset — confirm in shaders
    uint32_t miss;   // set to 1 for every test type; presumably a miss shader index — confirm in shaders
};
109 
getShaderGroupSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)110 uint32_t getShaderGroupSize(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
111 {
112     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
113 
114     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
115     return rayTracingPropertiesKHR->getShaderGroupHandleSize();
116 }
117 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)118 uint32_t getShaderGroupBaseAlignment(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
119 {
120     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
121 
122     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
123     return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
124 }
125 
makeImageCreateInfo(uint32_t width,uint32_t height,uint32_t depth,VkFormat format)126 VkImageCreateInfo makeImageCreateInfo(uint32_t width, uint32_t height, uint32_t depth, VkFormat format)
127 {
128     const VkImageUsageFlags usage =
129         VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
130     const VkImageCreateInfo imageCreateInfo = {
131         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
132         DE_NULL,                             // const void* pNext;
133         (VkImageCreateFlags)0u,              // VkImageCreateFlags flags;
134         VK_IMAGE_TYPE_3D,                    // VkImageType imageType;
135         format,                              // VkFormat format;
136         makeExtent3D(width, height, depth),  // VkExtent3D extent;
137         1u,                                  // uint32_t mipLevels;
138         1u,                                  // uint32_t arrayLayers;
139         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
140         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
141         usage,                               // VkImageUsageFlags usage;
142         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
143         0u,                                  // uint32_t queueFamilyIndexCount;
144         DE_NULL,                             // const uint32_t* pQueueFamilyIndices;
145         VK_IMAGE_LAYOUT_UNDEFINED            // VkImageLayout initialLayout;
146     };
147 
148     return imageCreateInfo;
149 }
150 
makePipelineLayout(const DeviceInterface & vk,const VkDevice device,const VkDescriptorSetLayout descriptorSetLayout,const uint32_t pushConstantsSize)151 Move<VkPipelineLayout> makePipelineLayout(const DeviceInterface &vk, const VkDevice device,
152                                           const VkDescriptorSetLayout descriptorSetLayout,
153                                           const uint32_t pushConstantsSize)
154 {
155     const VkDescriptorSetLayout *descriptorSetLayoutPtr =
156         (descriptorSetLayout == DE_NULL) ? DE_NULL : &descriptorSetLayout;
157     const uint32_t setLayoutCount               = (descriptorSetLayout == DE_NULL) ? 0u : 1u;
158     const VkPushConstantRange pushConstantRange = {
159         ALL_RAY_TRACING_STAGES, //  VkShaderStageFlags stageFlags;
160         0u,                     //  uint32_t offset;
161         pushConstantsSize,      //  uint32_t size;
162     };
163     const VkPushConstantRange *pPushConstantRanges        = (pushConstantsSize == 0) ? DE_NULL : &pushConstantRange;
164     const uint32_t pushConstantRangeCount                 = (pushConstantsSize == 0) ? 0 : 1u;
165     const VkPipelineLayoutCreateInfo pipelineLayoutParams = {
166         VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // VkStructureType sType;
167         DE_NULL,                                       // const void* pNext;
168         0u,                                            // VkPipelineLayoutCreateFlags flags;
169         setLayoutCount,                                // uint32_t setLayoutCount;
170         descriptorSetLayoutPtr,                        // const VkDescriptorSetLayout* pSetLayouts;
171         pushConstantRangeCount,                        // uint32_t pushConstantRangeCount;
172         pPushConstantRanges,                           // const VkPushConstantRange* pPushConstantRanges;
173     };
174 
175     return createPipelineLayout(vk, device, &pipelineLayoutParams);
176 }
177 
getVkBuffer(const de::MovePtr<BufferWithMemory> & buffer)178 VkBuffer getVkBuffer(const de::MovePtr<BufferWithMemory> &buffer)
179 {
180     VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
181 
182     return result;
183 }
184 
makeStridedDeviceAddressRegion(const DeviceInterface & vkd,const VkDevice device,VkBuffer buffer,uint32_t stride,uint32_t count)185 VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion(const DeviceInterface &vkd, const VkDevice device,
186                                                                VkBuffer buffer, uint32_t stride, uint32_t count)
187 {
188     if (buffer == DE_NULL)
189     {
190         return makeStridedDeviceAddressRegionKHR(0, 0, 0);
191     }
192     else
193     {
194         return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride,
195                                                  stride * count);
196     }
197 }
198 
// Returns a copy of 'str' in which every non-overlapping occurrence of 'from' is replaced by 'to'.
// Matches are found left to right, and the search resumes right after the inserted text, so
// occurrences of 'from' inside 'to' are not replaced again.
// An empty 'from' returns 'str' unchanged. An empty pattern matches at every position, so
// searching for it would otherwise never terminate.
static inline std::string replace(const std::string &str, const std::string &from, const std::string &to)
{
    std::string result(str);

    if (from.empty())
        return result;

    for (size_t pos = result.find(from); pos != std::string::npos; pos = result.find(from, pos + to.length()))
        result.replace(pos, from.length(), to);

    return result;
}
213 
// Test instance for one CaseDef. It detects which shaders are present in the binary
// collection, assigns them to shader groups, builds the ray tracing pipeline, shader
// binding tables and acceleration structures, and traces into a
// width x height x m_depth storage image.
class RayTracingComplexControlFlowInstance : public TestInstance
{
public:
    RayTracingComplexControlFlowInstance(Context &context, const CaseDef &data);
    ~RayTracingComplexControlFlowInstance(void);
    tcu::TestStatus iterate(void);

protected:
    // Gives consecutive group indices, starting at shaderGroupCounter, to the shaders of the
    // requested stage(s) present in shaders1 and/or shaders2 (at most one group per set).
    void calcShaderGroup(uint32_t &shaderGroupCounter, const VkShaderStageFlags shaders1,
                         const VkShaderStageFlags shaders2, const VkShaderStageFlags shaderStageFlags,
                         uint32_t &shaderGroup, uint32_t &shaderGroupCount) const;
    // Push-constant values selected by m_data.testType.
    PushConstants getPushConstants(void) const;
    // NOTE(review): defined outside the visible section; presumably the reference output values.
    std::vector<uint32_t> getExpectedValues(void) const;
    // Records and submits the trace. It presumably returns the host-visible readback buffer;
    // the end of the function is outside the visible section.
    de::MovePtr<BufferWithMemory> runTest(void);
    // Adds every detected shader to rayTracingPipeline in its group and creates the pipeline.
    Move<VkPipeline> makePipeline(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, VkPipelineLayout pipelineLayout);
    // Creates the shader binding table for 'groupCount' groups starting at 'group', or
    // returns an empty pointer when 'group' is not a valid group index.
    de::MovePtr<BufferWithMemory> createShaderBindingTable(const InstanceInterface &vki, const DeviceInterface &vkd,
                                                           const VkDevice device, const VkPhysicalDevice physicalDevice,
                                                           const VkPipeline pipeline, Allocator &allocator,
                                                           de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
                                                           const uint32_t group, const uint32_t groupCount = 1u);
    // Builds a top-level structure with one instance per bottom-level structure.
    de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure(
        VkCommandBuffer cmdBuffer,
        vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures);
    // Builds the list of bottom-level structures (currently a single one).
    vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures(VkCommandBuffer cmdBuffer);
    // Builds one bottom-level structure; startPos is currently unused.
    de::MovePtr<BottomLevelAccelerationStructure> initBottomAccelerationStructure(VkCommandBuffer cmdBuffer,
                                                                                  tcu::UVec2 &startPos);

private:
    // Declaration order matters: m_pushConstants is initialized from getPushConstants(),
    // which reads m_data, so m_data must be declared (and thus initialized) first.
    CaseDef m_data;
    VkShaderStageFlags m_shaders;  // stages with a primary shader ("rgen", "ahit", "chit", "miss", "sect", "call")
    VkShaderStageFlags m_shaders2; // stages with a secondary shader ("ahit2", "chit2", "miss2", "sect2", "cal0")
    // First group index of each shader kind; stays ~0u when the kind has no shaders.
    uint32_t m_raygenShaderGroup;
    uint32_t m_missShaderGroup;
    uint32_t m_hitShaderGroup;
    uint32_t m_callableShaderGroup;
    // Number of groups of each shader kind (0, 1 or 2).
    uint32_t m_raygenShaderGroupCount;
    uint32_t m_missShaderGroupCount;
    uint32_t m_hitShaderGroupCount;
    uint32_t m_callableShaderGroupCount;
    uint32_t m_shaderGroupCount; // total number of shader groups across all kinds
    uint32_t m_depth;            // depth of the 3D output image
    PushConstants m_pushConstants;
};
257 
RayTracingComplexControlFlowInstance(Context & context,const CaseDef & data)258 RayTracingComplexControlFlowInstance::RayTracingComplexControlFlowInstance(Context &context, const CaseDef &data)
259     : vkt::TestInstance(context)
260     , m_data(data)
261     , m_shaders(0)
262     , m_shaders2(0)
263     , m_raygenShaderGroup(~0u)
264     , m_missShaderGroup(~0u)
265     , m_hitShaderGroup(~0u)
266     , m_callableShaderGroup(~0u)
267     , m_raygenShaderGroupCount(0)
268     , m_missShaderGroupCount(0)
269     , m_hitShaderGroupCount(0)
270     , m_callableShaderGroupCount(0)
271     , m_shaderGroupCount(0)
272     , m_depth(16)
273     , m_pushConstants(getPushConstants())
274 {
275     const VkShaderStageFlags hitStages =
276         VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
277     BinaryCollection &collection = m_context.getBinaryCollection();
278     uint32_t shaderCount         = 0;
279 
280     if (collection.contains("rgen"))
281         m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
282     if (collection.contains("ahit"))
283         m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
284     if (collection.contains("chit"))
285         m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
286     if (collection.contains("miss"))
287         m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
288     if (collection.contains("sect"))
289         m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
290     if (collection.contains("call"))
291         m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
292 
293     if (collection.contains("ahit2"))
294         m_shaders2 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
295     if (collection.contains("chit2"))
296         m_shaders2 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
297     if (collection.contains("miss2"))
298         m_shaders2 |= VK_SHADER_STAGE_MISS_BIT_KHR;
299     if (collection.contains("sect2"))
300         m_shaders2 |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
301 
302     if (collection.contains("cal0"))
303         m_shaders2 |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
304 
305     for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
306         shaderCount++;
307 
308     if (shaderCount != (uint32_t)dePop32(m_shaders) + (uint32_t)dePop32(m_shaders2))
309         TCU_THROW(InternalError, "Unused shaders detected in the collection");
310 
311     calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_RAYGEN_BIT_KHR, m_raygenShaderGroup,
312                     m_raygenShaderGroupCount);
313     calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_MISS_BIT_KHR, m_missShaderGroup,
314                     m_missShaderGroupCount);
315     calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, hitStages, m_hitShaderGroup, m_hitShaderGroupCount);
316     calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_CALLABLE_BIT_KHR, m_callableShaderGroup,
317                     m_callableShaderGroupCount);
318 }
319 
~RayTracingComplexControlFlowInstance(void)320 RayTracingComplexControlFlowInstance::~RayTracingComplexControlFlowInstance(void)
321 {
322 }
323 
calcShaderGroup(uint32_t & shaderGroupCounter,const VkShaderStageFlags shaders1,const VkShaderStageFlags shaders2,const VkShaderStageFlags shaderStageFlags,uint32_t & shaderGroup,uint32_t & shaderGroupCount) const324 void RayTracingComplexControlFlowInstance::calcShaderGroup(uint32_t &shaderGroupCounter,
325                                                            const VkShaderStageFlags shaders1,
326                                                            const VkShaderStageFlags shaders2,
327                                                            const VkShaderStageFlags shaderStageFlags,
328                                                            uint32_t &shaderGroup, uint32_t &shaderGroupCount) const
329 {
330     const uint32_t shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
331     const uint32_t shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
332 
333     shaderGroupCount = shader1Count + shader2Count;
334 
335     if (shaderGroupCount != 0)
336     {
337         shaderGroup = shaderGroupCounter;
338         shaderGroupCounter += shaderGroupCount;
339     }
340 }
341 
makePipeline(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,VkPipelineLayout pipelineLayout)342 Move<VkPipeline> RayTracingComplexControlFlowInstance::makePipeline(de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
343                                                                     VkPipelineLayout pipelineLayout)
344 {
345     const DeviceInterface &vkd       = m_context.getDeviceInterface();
346     const VkDevice device            = m_context.getDevice();
347     vk::BinaryCollection &collection = m_context.getBinaryCollection();
348 
349     if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))
350         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
351                                       createShaderModule(vkd, device, collection.get("rgen"), 0), m_raygenShaderGroup);
352     if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))
353         rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
354                                       createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
355     if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))
356         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
357                                       createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
358     if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))
359         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
360                                       createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
361     if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
362         rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
363                                       createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
364     if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))
365         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
366                                       createShaderModule(vkd, device, collection.get("call"), 0),
367                                       m_callableShaderGroup + 1);
368 
369     if (0 != (m_shaders2 & VK_SHADER_STAGE_CALLABLE_BIT_KHR))
370         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
371                                       createShaderModule(vkd, device, collection.get("cal0"), 0),
372                                       m_callableShaderGroup);
373     if (0 != (m_shaders2 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))
374         rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
375                                       createShaderModule(vkd, device, collection.get("ahit2"), 0),
376                                       m_hitShaderGroup + 1);
377     if (0 != (m_shaders2 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))
378         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
379                                       createShaderModule(vkd, device, collection.get("chit2"), 0),
380                                       m_hitShaderGroup + 1);
381     if (0 != (m_shaders2 & VK_SHADER_STAGE_MISS_BIT_KHR))
382         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
383                                       createShaderModule(vkd, device, collection.get("miss2"), 0),
384                                       m_missShaderGroup + 1);
385     if (0 != (m_shaders2 & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
386         rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
387                                       createShaderModule(vkd, device, collection.get("sect2"), 0),
388                                       m_hitShaderGroup + 1);
389 
390     if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
391         rayTracingPipeline->setMaxRecursionDepth(2);
392 
393     Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout);
394 
395     return pipeline;
396 }
397 
createShaderBindingTable(const InstanceInterface & vki,const DeviceInterface & vkd,const VkDevice device,const VkPhysicalDevice physicalDevice,const VkPipeline pipeline,Allocator & allocator,de::MovePtr<RayTracingPipeline> & rayTracingPipeline,const uint32_t group,const uint32_t groupCount)398 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::createShaderBindingTable(
399     const InstanceInterface &vki, const DeviceInterface &vkd, const VkDevice device,
400     const VkPhysicalDevice physicalDevice, const VkPipeline pipeline, Allocator &allocator,
401     de::MovePtr<RayTracingPipeline> &rayTracingPipeline, const uint32_t group, const uint32_t groupCount)
402 {
403     de::MovePtr<BufferWithMemory> shaderBindingTable;
404 
405     if (group < m_shaderGroupCount)
406     {
407         const uint32_t shaderGroupHandleSize    = getShaderGroupSize(vki, physicalDevice);
408         const uint32_t shaderGroupBaseAlignment = getShaderGroupBaseAlignment(vki, physicalDevice);
409 
410         shaderBindingTable = rayTracingPipeline->createShaderBindingTable(
411             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
412     }
413 
414     return shaderBindingTable;
415 }
416 
initTopAccelerationStructure(VkCommandBuffer cmdBuffer,vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)417 de::MovePtr<TopLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initTopAccelerationStructure(
418     VkCommandBuffer cmdBuffer,
419     vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures)
420 {
421     const DeviceInterface &vkd                        = m_context.getDeviceInterface();
422     const VkDevice device                             = m_context.getDevice();
423     Allocator &allocator                              = m_context.getDefaultAllocator();
424     de::MovePtr<TopLevelAccelerationStructure> result = makeTopLevelAccelerationStructure();
425 
426     result->setInstanceCount(bottomLevelAccelerationStructures.size());
427 
428     for (size_t structNdx = 0; structNdx < bottomLevelAccelerationStructures.size(); ++structNdx)
429         result->addInstance(bottomLevelAccelerationStructures[structNdx]);
430 
431     result->createAndBuild(vkd, device, cmdBuffer, allocator);
432 
433     return result;
434 }
435 
initBottomAccelerationStructure(VkCommandBuffer cmdBuffer,tcu::UVec2 & startPos)436 de::MovePtr<BottomLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initBottomAccelerationStructure(
437     VkCommandBuffer cmdBuffer, tcu::UVec2 &startPos)
438 {
439     const DeviceInterface &vkd                           = m_context.getDeviceInterface();
440     const VkDevice device                                = m_context.getDevice();
441     Allocator &allocator                                 = m_context.getDefaultAllocator();
442     de::MovePtr<BottomLevelAccelerationStructure> result = makeBottomLevelAccelerationStructure();
443     const float z = (m_data.stage == VK_SHADER_STAGE_MISS_BIT_KHR) ? +1.0f : -1.0f;
444     std::vector<tcu::Vec3> geometryData;
445 
446     DE_UNREF(startPos);
447 
448     result->setGeometryCount(1);
449     geometryData.push_back(tcu::Vec3(0.0f, 0.0f, z));
450     geometryData.push_back(tcu::Vec3(1.0f, 1.0f, z));
451     result->addGeometry(geometryData, false);
452     result->createAndBuild(vkd, device, cmdBuffer, allocator);
453 
454     return result;
455 }
456 
457 vector<de::SharedPtr<BottomLevelAccelerationStructure>> RayTracingComplexControlFlowInstance::
initBottomAccelerationStructures(VkCommandBuffer cmdBuffer)458     initBottomAccelerationStructures(VkCommandBuffer cmdBuffer)
459 {
460     tcu::UVec2 startPos;
461     vector<de::SharedPtr<BottomLevelAccelerationStructure>> result;
462     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
463         initBottomAccelerationStructure(cmdBuffer, startPos);
464 
465     result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
466 
467     return result;
468 }
469 
getPushConstants(void) const470 PushConstants RayTracingComplexControlFlowInstance::getPushConstants(void) const
471 {
472     const uint32_t hitOfs = 1;
473     const uint32_t miss   = 1;
474     PushConstants result;
475 
476     switch (m_data.testType)
477     {
478     case TEST_TYPE_IF:
479     {
480         result = {32 | 8 | 1, 10000, 0x0F, 0xF0, hitOfs, miss};
481 
482         break;
483     }
484     case TEST_TYPE_LOOP:
485     {
486         result = {8, 10000, 0x0F, 100000, hitOfs, miss};
487 
488         break;
489     }
490     case TEST_TYPE_SWITCH:
491     {
492         result = {3, 10000, 0x07, 100000, hitOfs, miss};
493 
494         break;
495     }
496     case TEST_TYPE_LOOP_DOUBLE_CALL:
497     {
498         result = {7, 10000, 0x0F, 0xF0, hitOfs, miss};
499 
500         break;
501     }
502     case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
503     {
504         result = {16, 5, 0x0F, 0xF0, hitOfs, miss};
505 
506         break;
507     }
508     case TEST_TYPE_NESTED_LOOP:
509     {
510         result = {8, 5, 0x0F, 0x09, hitOfs, miss};
511 
512         break;
513     }
514     case TEST_TYPE_NESTED_LOOP_BEFORE:
515     {
516         result = {9, 16, 0x0F, 10, hitOfs, miss};
517 
518         break;
519     }
520     case TEST_TYPE_NESTED_LOOP_AFTER:
521     {
522         result = {9, 16, 0x0F, 10, hitOfs, miss};
523 
524         break;
525     }
526     case TEST_TYPE_FUNCTION_CALL:
527     {
528         result = {0xFFB, 16, 10, 100000, hitOfs, miss};
529 
530         break;
531     }
532     case TEST_TYPE_NESTED_FUNCTION_CALL:
533     {
534         result = {0xFFB, 16, 10, 100000, hitOfs, miss};
535 
536         break;
537     }
538 
539     default:
540         TCU_THROW(InternalError, "Unknown testType");
541     }
542 
543     return result;
544 }
545 
runTest(void)546 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::runTest(void)
547 {
548     const InstanceInterface &vki          = m_context.getInstanceInterface();
549     const DeviceInterface &vkd            = m_context.getDeviceInterface();
550     const VkDevice device                 = m_context.getDevice();
551     const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
552     const uint32_t queueFamilyIndex       = m_context.getUniversalQueueFamilyIndex();
553     const VkQueue queue                   = m_context.getUniversalQueue();
554     Allocator &allocator                  = m_context.getDefaultAllocator();
555     const VkFormat format                 = VK_FORMAT_R32_UINT;
556     const uint32_t pushConstants[]        = {m_pushConstants.a, m_pushConstants.b,      m_pushConstants.c,
557                                              m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss};
558     const uint32_t pushConstantsSize      = sizeof(pushConstants);
559     const uint32_t pixelCount             = m_data.width * m_data.height * m_depth;
560     const uint32_t shaderGroupHandleSize  = getShaderGroupSize(vki, physicalDevice);
561 
562     const Move<VkDescriptorSetLayout> descriptorSetLayout =
563         DescriptorSetLayoutBuilder()
564             .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
565             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
566             .build(vkd, device);
567     const Move<VkDescriptorPool> descriptorPool =
568         DescriptorPoolBuilder()
569             .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
570             .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
571             .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
572     const Move<VkDescriptorSet> descriptorSet = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
573     const Move<VkPipelineLayout> pipelineLayout =
574         makePipelineLayout(vkd, device, descriptorSetLayout.get(), pushConstantsSize);
575     const Move<VkCommandPool> cmdPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
576     const Move<VkCommandBuffer> cmdBuffer =
577         allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
578 
579     de::MovePtr<RayTracingPipeline> rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
580     const Move<VkPipeline> pipeline                    = makePipeline(rayTracingPipeline, *pipelineLayout);
581     const de::MovePtr<BufferWithMemory> raygenShaderBindingTable =
582         createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline,
583                                  m_raygenShaderGroup, m_raygenShaderGroupCount);
584     const de::MovePtr<BufferWithMemory> missShaderBindingTable =
585         createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline,
586                                  m_missShaderGroup, m_missShaderGroupCount);
587     const de::MovePtr<BufferWithMemory> hitShaderBindingTable =
588         createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline,
589                                  m_hitShaderGroup, m_hitShaderGroupCount);
590     const de::MovePtr<BufferWithMemory> callableShaderBindingTable =
591         createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline,
592                                  m_callableShaderGroup, m_callableShaderGroupCount);
593 
594     const VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion = makeStridedDeviceAddressRegion(
595         vkd, device, getVkBuffer(raygenShaderBindingTable), shaderGroupHandleSize, m_raygenShaderGroupCount);
596     const VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion = makeStridedDeviceAddressRegion(
597         vkd, device, getVkBuffer(missShaderBindingTable), shaderGroupHandleSize, m_missShaderGroupCount);
598     const VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion = makeStridedDeviceAddressRegion(
599         vkd, device, getVkBuffer(hitShaderBindingTable), shaderGroupHandleSize, m_hitShaderGroupCount);
600     const VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegion(
601         vkd, device, getVkBuffer(callableShaderBindingTable), shaderGroupHandleSize, m_callableShaderGroupCount);
602 
603     const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_data.width, m_data.height, m_depth, format);
604     const VkImageSubresourceRange imageSubresourceRange =
605         makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0, 1u);
606     const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(
607         new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
608     const Move<VkImageView> imageView =
609         makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, format, imageSubresourceRange);
610 
611     const VkBufferCreateInfo bufferCreateInfo =
612         makeBufferCreateInfo(pixelCount * sizeof(uint32_t), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
613     const VkImageSubresourceLayers bufferImageSubresourceLayers =
614         makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
615     const VkBufferImageCopy bufferImageRegion =
616         makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, m_depth), bufferImageSubresourceLayers);
617     de::MovePtr<BufferWithMemory> buffer = de::MovePtr<BufferWithMemory>(
618         new BufferWithMemory(vkd, device, allocator, bufferCreateInfo, MemoryRequirement::HostVisible));
619 
620     const VkDescriptorImageInfo descriptorImageInfo =
621         makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
622 
623     const VkImageMemoryBarrier preImageBarrier =
624         makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
625                                VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, **image, imageSubresourceRange);
626     const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(
627         VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
628         VK_IMAGE_LAYOUT_GENERAL, **image, imageSubresourceRange);
629     const VkMemoryBarrier preTraceMemoryBarrier =
630         makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
631     const VkMemoryBarrier postTraceMemoryBarrier =
632         makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
633     const VkMemoryBarrier postCopyMemoryBarrier =
634         makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
635     const VkClearValue clearValue = makeClearValueColorU32(DEFAULT_CLEAR_VALUE, 0u, 0u, 255u);
636 
637     vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelAccelerationStructures;
638     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
639 
640     DE_ASSERT(DE_LENGTH_OF_ARRAY(pushConstants) == PUSH_CONSTANTS_COUNT);
641 
642     beginCommandBuffer(vkd, *cmdBuffer, 0u);
643     {
644         vkd.cmdPushConstants(*cmdBuffer, *pipelineLayout, ALL_RAY_TRACING_STAGES, 0, pushConstantsSize,
645                              &m_pushConstants);
646 
647         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
648                                       VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
649         vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1,
650                                &imageSubresourceRange);
651         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
652                                       VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, &postImageBarrier);
653 
654         bottomLevelAccelerationStructures = initBottomAccelerationStructures(*cmdBuffer);
655         topLevelAccelerationStructure     = initTopAccelerationStructure(*cmdBuffer, bottomLevelAccelerationStructures);
656 
657         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
658                                  VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, &preTraceMemoryBarrier);
659 
660         const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = topLevelAccelerationStructure.get();
661         VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
662             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
663             DE_NULL,                                                           //  const void* pNext;
664             1u,                                                                //  uint32_t accelerationStructureCount;
665             topLevelAccelerationStructurePtr->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
666         };
667 
668         DescriptorSetUpdateBuilder()
669             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
670                          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
671             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
672                          VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
673             .update(vkd, device);
674 
675         vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1,
676                                   &descriptorSet.get(), 0, DE_NULL);
677 
678         vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
679 
680         cmdTraceRays(vkd, *cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
681                      &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, m_data.width, m_data.height, 1);
682 
683         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR,
684                                  VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
685 
686         vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **buffer, 1u, &bufferImageRegion);
687 
688         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
689                                  &postCopyMemoryBarrier);
690     }
691     endCommandBuffer(vkd, *cmdBuffer);
692 
693     submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
694 
695     invalidateMappedMemoryRange(vkd, device, buffer->getAllocation().getMemory(), buffer->getAllocation().getOffset(),
696                                 pixelCount * sizeof(uint32_t));
697 
698     return buffer;
699 }
700 
// Builds the reference contents that iterate() compares against the read-back image buffer.
//
// The returned vector holds m_depth slices of (width * height) values, indexed as
// [z * plainSize + id] with id = x + width * y. The slices filled here mirror the image
// stores emitted by initPrograms():
//   z = 0      : the per-invocation "result" accumulated by the control-flow construct under test
//   z = 1..6   : push constants a, b, c, d, hitOfs, miss (same value for every texel)
//   z = 7      : the linear launch index n stored by the callee code (calleeMainPart)
//   z = 8..15  : the payload .y value stored by the callee at z = (payload .x % 8) + 8
// Any texel not written keeps DEFAULT_CLEAR_VALUE, the value runTest() clears the image to.
//
// "fixed" selects the TEST_OP_REPORT_INTERSECTION model, in which no payload is incremented.
// For the other ops the model expects the callee to store payload .y and then increment it.
// NOTE(review): the increment is inferred from the `if (!fixed) vN++` pattern below. The callee
// shader body is defined outside this view.
std::vector<uint32_t> RayTracingComplexControlFlowInstance::getExpectedValues(void) const
{
    const uint32_t plainSize       = m_data.width * m_data.height;
    // Offset of slice 8, where the callee stores payloads whose .x % 8 == 0.
    const uint32_t plain8Ofs       = 8 * plainSize;
    const struct PushConstants &p  = m_pushConstants;
    // Element z is the value expected in slice z. Element 0 is a placeholder because slice 0 holds the result.
    const uint32_t pushConstants[] = {0,
                                      m_pushConstants.a,
                                      m_pushConstants.b,
                                      m_pushConstants.c,
                                      m_pushConstants.d,
                                      m_pushConstants.hitOfs,
                                      m_pushConstants.miss};
    const uint32_t resultSize      = plainSize * m_depth;
    const bool fixed               = m_data.testOp == TEST_OP_REPORT_INTERSECTION;
    std::vector<uint32_t> result(resultSize, DEFAULT_CLEAR_VALUE);
    // CPU stand-ins for the shader's uvec2 variables v0..v3. Only their .y components are modelled.
    uint32_t v0;
    uint32_t v1;
    uint32_t v2;
    uint32_t v3;

    switch (m_data.testType)
    {
    case TEST_TYPE_IF:
    {
        // One call per invocation, from whichever branch is taken. Both payloads have .x == 0,
        // so the callee's store lands in slice 8.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            v2 = v3 = p.b;

            if ((p.a & id) != 0)
            {
                v0 = p.c & id;
                v1 = (p.d & id) + 1;

                result[plain8Ofs + id] = v0;
                if (!fixed)
                    v0++;
            }
            else
            {
                v0 = p.d & id;
                v1 = (p.c & id) + 1;

                if (!fixed)
                {
                    result[plain8Ofs + id] = v1;
                    v1++;
                }
                else
                    // With reportIntersection, slice 8 is modelled as v0 even in this branch.
                    result[plain8Ofs + id] = v0;
            }

            result[id] = v0 + v1 + v2 + v3;
        }

        break;
    }
    case TEST_TYPE_LOOP:
    {
        // One call per iteration with payload v0 = (n, (c & id) + n). Later iterations
        // overwrite slice (n % 8) + 8, so the last write wins, as on the device.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            v1 = v3 = p.b;

            for (uint32_t n = 0; n < p.a; n++)
            {
                v0 = (p.c & id) + n;

                result[((n % 8) + 8) * plainSize + id] = v0;
                if (!fixed)
                    v0++;

                result[id] += v0 + v1 + v3;
            }
        }

        break;
    }
    case TEST_TYPE_SWITCH:
    {
        // At most one call, from the case selected by (a & id). The modified variable always
        // has .x == 0, so the callee's store goes to slice 8.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            switch (p.a & id)
            {
            case 0:
            {
                v1 = v2 = v3 = p.b;
                v0           = p.c & id;
                break;
            }
            case 1:
            {
                v0 = v2 = v3 = p.b;
                v1           = p.c & id;
                break;
            }
            case 2:
            {
                v0 = v1 = v3 = p.b;
                v2           = p.c & id;
                break;
            }
            case 3:
            {
                v0 = v1 = v2 = p.b;
                v3           = p.c & id;
                break;
            }
            default:
            {
                // NOTE(review): the shader's default case issues no call, yet slice 8 is still
                // filled below. Presumably p.a keeps (a & id) within 0..3. Confirm against the
                // push-constant values chosen for this test.
                v0 = v1 = v2 = v3 = 0;
                break;
            }
            }

            if (!fixed)
                result[plain8Ofs + id] = p.c & id;
            else
                result[plain8Ofs + id] = v0;

            result[id] = v0 + v1 + v2 + v3;

            // The one payload the callee received is modelled as incremented by one.
            if (!fixed)
                result[id]++;
        }

        break;
    }
    case TEST_TYPE_LOOP_DOUBLE_CALL:
    {
        // Two calls per iteration: payload v0 (.x = 2x), then v1 (.x = 2x + 1). With
        // reportIntersection only the v0 store is modelled and no payload is incremented.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            v3 = p.b;

            for (uint32_t x = 0; x < p.a; x++)
            {
                v0 = (p.c & id) + x;
                v1 = (p.d & id) + x + 1;

                result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
                if (!fixed)
                    v0++;

                if (!fixed)
                {
                    result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
                    v1++;
                }

                result[id] += v0 + v1 + v3;
            }
        }

        break;
    }
    case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
    {
        // Same as TEST_TYPE_LOOP_DOUBLE_CALL, but calls only happen on iterations where
        // (x & b) != 0, and v3 is seeded with a + b.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            v3 = p.a + p.b;

            for (uint32_t x = 0; x < p.a; x++)
            {
                if ((x & p.b) != 0)
                {
                    v0 = (p.c & id) + x;
                    v1 = (p.d & id) + x + 1;

                    result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
                    if (!fixed)
                        v0++;

                    if (!fixed)
                    {
                        result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
                        v1++;
                    }

                    result[id] += v0 + v1 + v3;
                }
            }
        }

        break;
    }
    case TEST_TYPE_NESTED_LOOP:
    {
        // An a x a nested loop flattened to n = x + y * a. A call happens only when (n & d) != 0,
        // with payload v0 = (n, (c & id) + n).
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            v1 = v3 = p.b;

            for (uint32_t y = 0; y < p.a; y++)
                for (uint32_t x = 0; x < p.a; x++)
                {
                    const uint32_t n = x + y * p.a;

                    if ((n & p.d) != 0)
                    {
                        v0 = (p.c & id) + n;

                        result[((n % 8) + 8) * plainSize + id] = v0;
                        if (!fixed)
                            v0++;

                        result[id] += v0 + v1 + v3;
                    }
                }
        }

        break;
    }
    case TEST_TYPE_NESTED_LOOP_BEFORE:
    {
        // A call-free d x d nested loop runs first. It is followed by a loop over b that calls
        // with payload v0 = (x, c & id) whenever (x & a) != 0.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            for (uint32_t y = 0; y < p.d; y++)
                for (uint32_t x = 0; x < p.d; x++)
                {
                    if (((x + y * p.a) & p.b) != 0)
                        result[id] += (x + y);
                }

            v1 = v3 = p.a;

            for (uint32_t x = 0; x < p.b; x++)
            {
                if ((x & p.a) != 0)
                {
                    v0 = p.c & id;

                    result[((x % 8) + 8) * plainSize + id] = v0;
                    if (!fixed)
                        v0++;

                    result[id] += v0 + v1 + v3;
                }
            }
        }

        break;
    }
    case TEST_TYPE_NESTED_LOOP_AFTER:
    {
        // Same pieces as TEST_TYPE_NESTED_LOOP_BEFORE, but the call-free d x d loop runs after
        // the calling loop.
        for (uint32_t id = 0; id < plainSize; ++id)
        {
            result[id] = 0;

            v1 = v3 = p.a;

            for (uint32_t x = 0; x < p.b; x++)
            {
                if ((x & p.a) != 0)
                {
                    v0 = p.c & id;

                    result[((x % 8) + 8) * plainSize + id] = v0;
                    if (!fixed)
                        v0++;

                    result[id] += v0 + v1 + v3;
                }
            }

            for (uint32_t y = 0; y < p.d; y++)
                for (uint32_t x = 0; x < p.d; x++)
                {
                    if (((x + y * p.a) & p.b) != 0)
                        result[id] += (x + y);
                }
        }

        break;
    }
    case TEST_TYPE_FUNCTION_CALL:
    {
        // Mirrors the shader helper f1(). It fills a 42-element array with c * i, calls with
        // payload v0, sums the array and returns r + i, where i equals the array length after
        // the summing loop.
        uint32_t a[42];

        for (uint32_t id = 0; id < plainSize; ++id)
        {
            uint32_t r = 0;
            uint32_t i;

            v0 = p.a & id;
            v1 = v3 = p.d;

            for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
                a[i] = p.c * i;

            result[plain8Ofs + id] = v0;
            if (!fixed)
                v0++;

            for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
                r += a[i];

            result[id] = (r + i) + v0 + v1 + v3;
        }

        break;
    }
    case TEST_TYPE_NESTED_FUNCTION_CALL:
    {
        // Models two helper levels using a 14-element and a 256-element array. (r + i) and
        // (t + j) are each array's sum plus its length.
        // NOTE(review): the matching shader source is generated past this view. Confirm the
        // helper structure against initPrograms().
        uint32_t a[14];
        uint32_t b[256];

        for (uint32_t id = 0; id < plainSize; ++id)
        {
            uint32_t r = 0;
            uint32_t i;
            uint32_t t = 0;
            uint32_t j;

            v0 = p.a & id;
            v3 = p.d;

            for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
                b[j] = p.c * j;

            v1 = p.b;

            for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
                a[i] = p.c * i;

            result[plain8Ofs + id] = v0;
            if (!fixed)
                v0++;

            for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
                r += a[i];

            for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
                t += b[j];

            result[id] = (r + i) + (t + j) + v0 + v1 + v3;
        }

        break;
    }

    default:
        TCU_THROW(InternalError, "Unknown testType");
    }

    // Slice 7: the callee stores the linear launch index n = x + width * y.
    {
        const uint32_t startOfs = 7 * plainSize;

        for (uint32_t n = 0; n < plainSize; ++n)
            result[startOfs + n] = n;
    }

    // Slices 1..6: the raygen shader stores each push constant uniformly.
    for (uint32_t z = 1; z < DE_LENGTH_OF_ARRAY(pushConstants); ++z)
    {
        const uint32_t startOfs     = z * plainSize;
        const uint32_t pushConstant = pushConstants[z];

        for (uint32_t n = 0; n < plainSize; ++n)
            result[startOfs + n] = pushConstant;
    }

    return result;
}
1070 
iterate(void)1071 tcu::TestStatus RayTracingComplexControlFlowInstance::iterate(void)
1072 {
1073     const de::MovePtr<BufferWithMemory> buffer = runTest();
1074     const uint32_t *bufferPtr                  = (uint32_t *)buffer->getAllocation().getHostPtr();
1075     const vector<uint32_t> expected            = getExpectedValues();
1076     tcu::TestLog &log                          = m_context.getTestContext().getLog();
1077     uint32_t failures                          = 0;
1078     uint32_t pos                               = 0;
1079 
1080     for (uint32_t z = 0; z < m_depth; ++z)
1081         for (uint32_t y = 0; y < m_data.height; ++y)
1082             for (uint32_t x = 0; x < m_data.width; ++x)
1083             {
1084                 if (bufferPtr[pos] != expected[pos])
1085                     failures++;
1086 
1087                 ++pos;
1088             }
1089 
1090     if (failures != 0)
1091     {
1092         uint32_t pos0 = 0;
1093         uint32_t pos1 = 0;
1094         std::stringstream css;
1095 
1096         for (uint32_t z = 0; z < m_depth; ++z)
1097         {
1098             css << "z=" << z << std::endl;
1099 
1100             for (uint32_t y = 0; y < m_data.height; ++y)
1101             {
1102                 for (uint32_t x = 0; x < m_data.width; ++x)
1103                     css << std::setw(6) << bufferPtr[pos0++] << ' ';
1104 
1105                 css << "    ";
1106 
1107                 for (uint32_t x = 0; x < m_data.width; ++x)
1108                     css << std::setw(6) << expected[pos1++] << ' ';
1109 
1110                 css << std::endl;
1111             }
1112 
1113             css << std::endl;
1114         }
1115 
1116         log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
1117     }
1118 
1119     if (failures == 0)
1120         return tcu::TestStatus::pass("Pass");
1121     else
1122         return tcu::TestStatus::fail("failures=" + de::toString(failures));
1123 }
1124 
1125 class ComplexControlFlowTestCase : public TestCase
1126 {
1127 public:
1128     ComplexControlFlowTestCase(tcu::TestContext &context, const char *name, const CaseDef data);
1129     ~ComplexControlFlowTestCase(void);
1130 
1131     virtual void initPrograms(SourceCollections &programCollection) const;
1132     virtual TestInstance *createInstance(Context &context) const;
1133     virtual void checkSupport(Context &context) const;
1134 
1135 private:
1136     static inline const std::string getIntersectionPassthrough(void);
1137     static inline const std::string getMissPassthrough(void);
1138     static inline const std::string getHitPassthrough(void);
1139 
1140     CaseDef m_data;
1141 };
1142 
ComplexControlFlowTestCase(tcu::TestContext & context,const char * name,const CaseDef data)1143 ComplexControlFlowTestCase::ComplexControlFlowTestCase(tcu::TestContext &context, const char *name, const CaseDef data)
1144     : vkt::TestCase(context, name)
1145     , m_data(data)
1146 {
1147 }
1148 
~ComplexControlFlowTestCase(void)1149 ComplexControlFlowTestCase::~ComplexControlFlowTestCase(void)
1150 {
1151 }
1152 
checkSupport(Context & context) const1153 void ComplexControlFlowTestCase::checkSupport(Context &context) const
1154 {
1155     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1156 
1157     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
1158         context.getAccelerationStructureFeatures();
1159 
1160     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
1161         TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires "
1162                              "VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1163 
1164     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1165 
1166     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
1167         context.getRayTracingPipelineFeatures();
1168 
1169     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
1170         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1171 
1172     const VkPhysicalDeviceRayTracingPipelinePropertiesKHR &rayTracingPipelinePropertiesKHR =
1173         context.getRayTracingPipelineProperties();
1174 
1175     if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
1176     {
1177         if (rayTracingPipelinePropertiesKHR.maxRayRecursionDepth < 2)
1178             TCU_THROW(NotSupportedError,
1179                       "rayTracingPipelinePropertiesKHR.maxRayRecursionDepth is smaller than required");
1180     }
1181 }
1182 
getIntersectionPassthrough(void)1183 const std::string ComplexControlFlowTestCase::getIntersectionPassthrough(void)
1184 {
1185     const std::string intersectionPassthrough = "#version 460 core\n"
1186                                                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1187                                                 "#extension GL_EXT_ray_tracing : require\n"
1188                                                 "hitAttributeEXT vec3 hitAttribute;\n"
1189                                                 "\n"
1190                                                 "void main()\n"
1191                                                 "{\n"
1192                                                 "  reportIntersectionEXT(0.95f, 0u);\n"
1193                                                 "}\n";
1194 
1195     return intersectionPassthrough;
1196 }
1197 
getMissPassthrough(void)1198 const std::string ComplexControlFlowTestCase::getMissPassthrough(void)
1199 {
1200     const std::string missPassthrough = "#version 460 core\n"
1201                                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1202                                         "#extension GL_EXT_ray_tracing : require\n"
1203                                         "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1204                                         "\n"
1205                                         "void main()\n"
1206                                         "{\n"
1207                                         "}\n";
1208 
1209     return missPassthrough;
1210 }
1211 
getHitPassthrough(void)1212 const std::string ComplexControlFlowTestCase::getHitPassthrough(void)
1213 {
1214     const std::string hitPassthrough = "#version 460 core\n"
1215                                        "#extension GL_EXT_nonuniform_qualifier : enable\n"
1216                                        "#extension GL_EXT_ray_tracing : require\n"
1217                                        "hitAttributeEXT vec3 attribs;\n"
1218                                        "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1219                                        "\n"
1220                                        "void main()\n"
1221                                        "{\n"
1222                                        "}\n";
1223 
1224     return hitPassthrough;
1225 }
1226 
initPrograms(SourceCollections & programCollection) const1227 void ComplexControlFlowTestCase::initPrograms(SourceCollections &programCollection) const
1228 {
1229     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1230     const std::string calleeMainPart =
1231         "  uint z = (inValue.x % 8) + 8;\n"
1232         "  uint v = inValue.y;\n"
1233         "  uint n = gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y;\n"
1234         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, z), uvec4(v, 0, 0, 1));\n"
1235         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 7), uvec4(n, 0, 0, 1));\n";
1236     const std::string idTemplate = "$";
1237     const std::string shaderCallInstruction =
1238         (m_data.testOp == TEST_OP_EXECUTE_CALLABLE) ?
1239             "executeCallableEXT(0, " + idTemplate + ")" :
1240         (m_data.testOp == TEST_OP_TRACE_RAY) ?
1241             "traceRayEXT(as, 0, 0xFF, p.hitOfs, 0, p.miss, vec3((gl_LaunchIDEXT.x) + vec3(0.5f)) / "
1242             "vec3(gl_LaunchSizeEXT), 1.0f, vec3(0.0f, 0.0f, 1.0f), 100.0f, " +
1243                 idTemplate + ")" :
1244         (m_data.testOp == TEST_OP_REPORT_INTERSECTION) ? "reportIntersectionEXT(1.0f, 0u)" :
1245                                                          "TEST_OP_NOT_IMPLEMENTED_FAILURE";
1246     std::string declsPreMain        = "#version 460 core\n"
1247                                       "#extension GL_EXT_nonuniform_qualifier : enable\n"
1248                                       "#extension GL_EXT_ray_tracing : require\n"
1249                                       "\n"
1250                                       "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1251                                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT as;\n"
1252                                       "\n"
1253                                       "layout(push_constant) uniform TestParams\n"
1254                                       "{\n"
1255                                       "    uint a;\n"
1256                                       "    uint b;\n"
1257                                       "    uint c;\n"
1258                                       "    uint d;\n"
1259                                       "    uint hitOfs;\n"
1260                                       "    uint miss;\n"
1261                                       "} p;\n";
1262     std::string declsInMainBeforeOp = "  uint result = 0;\n"
1263                                       "  uint id = uint(gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y);\n";
1264     std::string declsInMainAfterOp =
1265         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 0), uvec4(result, 0, 0, 1));\n"
1266         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 1), uvec4(p.a, 0, 0, 1));\n"
1267         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 2), uvec4(p.b, 0, 0, 1));\n"
1268         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 3), uvec4(p.c, 0, 0, 1));\n"
1269         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 4), uvec4(p.d, 0, 0, 1));\n"
1270         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 5), uvec4(p.hitOfs, 0, 0, 1));\n"
1271         "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 6), uvec4(p.miss, 0, 0, 1));\n";
1272     std::string opInMain  = "";
1273     std::string opPreMain = "";
1274 
1275     DE_ASSERT(!declsPreMain.empty() && PUSH_CONSTANTS_COUNT == 6);
1276 
1277     switch (m_data.testType)
1278     {
1279     case TEST_TYPE_IF:
1280     {
1281         opInMain = "  v2 = v3 = uvec2(0, p.b);\n"
1282                    "\n"
1283                    "  if ((p.a & id) != 0)\n"
1284                    "      { v0 = uvec2(0, p.c & id); v1 = uvec2(0, (p.d & id) + 1);" +
1285                    replace(shaderCallInstruction, idTemplate, "0") +
1286                    "; }\n"
1287                    "  else\n"
1288                    "      { v0 = uvec2(0, p.d & id); v1 = uvec2(0, (p.c & id) + 1);" +
1289                    replace(shaderCallInstruction, idTemplate, "1") +
1290                    "; }\n"
1291                    "\n"
1292                    "  result = v0.y + v1.y + v2.y + v3.y;\n";
1293 
1294         break;
1295     }
1296     case TEST_TYPE_LOOP:
1297     {
1298         opInMain = "  v1 = v3 = uvec2(0, p.b);\n"
1299                    "\n"
1300                    "  for (uint x = 0; x < p.a; x++)\n"
1301                    "  {\n"
1302                    "    v0 = uvec2(x, (p.c & id) + x);\n"
1303                    "    " +
1304                    replace(shaderCallInstruction, idTemplate, "0") +
1305                    ";\n"
1306                    "    result += v0.y + v1.y + v3.y;\n"
1307                    "  }\n";
1308 
1309         break;
1310     }
1311     case TEST_TYPE_SWITCH:
1312     {
1313         opInMain = "  switch (p.a & id)\n"
1314                    "  {\n"
1315                    "    case 0: { v1 = v2 = v3 = uvec2(0, p.b); v0 = uvec2(0, p.c & id); " +
1316                    replace(shaderCallInstruction, idTemplate, "0") +
1317                    "; break; }\n"
1318                    "    case 1: { v0 = v2 = v3 = uvec2(0, p.b); v1 = uvec2(0, p.c & id); " +
1319                    replace(shaderCallInstruction, idTemplate, "1") +
1320                    "; break; }\n"
1321                    "    case 2: { v0 = v1 = v3 = uvec2(0, p.b); v2 = uvec2(0, p.c & id); " +
1322                    replace(shaderCallInstruction, idTemplate, "2") +
1323                    "; break; }\n"
1324                    "    case 3: { v0 = v1 = v2 = uvec2(0, p.b); v3 = uvec2(0, p.c & id); " +
1325                    replace(shaderCallInstruction, idTemplate, "3") +
1326                    "; break; }\n"
1327                    "    default: break;\n"
1328                    "  }\n"
1329                    "\n"
1330                    "  result = v0.y + v1.y + v2.y + v3.y;\n";
1331 
1332         break;
1333     }
1334     case TEST_TYPE_LOOP_DOUBLE_CALL:
1335     {
1336         opInMain = "  v3 = uvec2(0, p.b);\n"
1337                    "  for (uint x = 0; x < p.a; x++)\n"
1338                    "  {\n"
1339                    "    v0 = uvec2(2 * x + 0, (p.c & id) + x);\n"
1340                    "    v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1341                    "    " +
1342                    replace(shaderCallInstruction, idTemplate, "0") +
1343                    ";\n"
1344                    "    " +
1345                    replace(shaderCallInstruction, idTemplate, "1") +
1346                    ";\n"
1347                    "    result += v0.y + v1.y + v3.y;\n"
1348                    "  }\n";
1349 
1350         break;
1351     }
1352     case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
1353     {
1354         opInMain = "  v3 = uvec2(0, p.a + p.b);\n"
1355                    "  for (uint x = 0; x < p.a; x++)\n"
1356                    "    if ((x & p.b) != 0)\n"
1357                    "    {\n"
1358                    "      v0 = uvec2(2 * x + 0, (p.c & id) + x + 0);\n"
1359                    "      v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1360                    "      " +
1361                    replace(shaderCallInstruction, idTemplate, "0") +
1362                    ";\n"
1363                    "      " +
1364                    replace(shaderCallInstruction, idTemplate, "1") +
1365                    ";\n"
1366                    "      result += v0.y + v1.y + v3.y;\n"
1367                    "    }\n"
1368                    "\n";
1369 
1370         break;
1371     }
1372     case TEST_TYPE_NESTED_LOOP:
1373     {
1374         opInMain = "  v1 = v3 = uvec2(0, p.b);\n"
1375                    "  for (uint y = 0; y < p.a; y++)\n"
1376                    "  for (uint x = 0; x < p.a; x++)\n"
1377                    "  {\n"
1378                    "    uint n = x + y * p.a;\n"
1379                    "    if ((n & p.d) != 0)\n"
1380                    "    {\n"
1381                    "      v0 = uvec2(n, (p.c & id) + (x + y * p.a));\n"
1382                    "      " +
1383                    replace(shaderCallInstruction, idTemplate, "0") +
1384                    ";\n"
1385                    "      result += v0.y + v1.y + v3.y;\n"
1386                    "    }\n"
1387                    "  }\n"
1388                    "\n";
1389 
1390         break;
1391     }
1392     case TEST_TYPE_NESTED_LOOP_BEFORE:
1393     {
1394         opInMain = "  for (uint y = 0; y < p.d; y++)\n"
1395                    "  for (uint x = 0; x < p.d; x++)\n"
1396                    "    if (((x + y * p.a) & p.b) != 0)\n"
1397                    "      result += (x + y);\n"
1398                    "\n"
1399                    "  v1 = v3 = uvec2(0, p.a);\n"
1400                    "\n"
1401                    "  for (uint x = 0; x < p.b; x++)\n"
1402                    "    if ((x & p.a) != 0)\n"
1403                    "    {\n"
1404                    "      v0 = uvec2(x, p.c & id);\n"
1405                    "      " +
1406                    replace(shaderCallInstruction, idTemplate, "0") +
1407                    ";\n"
1408                    "      result += v0.y + v1.y + v3.y;\n"
1409                    "    }\n";
1410 
1411         break;
1412     }
1413     case TEST_TYPE_NESTED_LOOP_AFTER:
1414     {
1415         opInMain = "  v1 = v3 = uvec2(0, p.a); \n"
1416                    "  for (uint x = 0; x < p.b; x++)\n"
1417                    "    if ((x & p.a) != 0)\n"
1418                    "    {\n"
1419                    "      v0 = uvec2(x, p.c & id);\n"
1420                    "      " +
1421                    replace(shaderCallInstruction, idTemplate, "0") +
1422                    ";\n"
1423                    "      result += v0.y + v1.y + v3.y;\n"
1424                    "    }\n"
1425                    "\n"
1426                    "  for (uint y = 0; y < p.d; y++)\n"
1427                    "  for (uint x = 0; x < p.d; x++)\n"
1428                    "    if (((x + y * p.a) & p.b) != 0)\n"
1429                    "      result += x + y;\n";
1430 
1431         break;
1432     }
1433     case TEST_TYPE_FUNCTION_CALL:
1434     {
1435         opPreMain = "uint f1(void)\n"
1436                     "{\n"
1437                     "  uint i, r = 0;\n"
1438                     "  uint a[42];\n"
1439                     "\n"
1440                     "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1441                     "\n"
1442                     "  " +
1443                     replace(shaderCallInstruction, idTemplate, "0") +
1444                     ";\n"
1445                     "\n"
1446                     "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1447                     "\n"
1448                     "  return r + i;\n"
1449                     "}\n";
1450         opInMain = "  v0 = uvec2(0, p.a & id); v1 = v3 = uvec2(0, p.d);\n"
1451                    "  result = f1() + v0.y + v1.y + v3.y;\n";
1452 
1453         break;
1454     }
1455     case TEST_TYPE_NESTED_FUNCTION_CALL:
1456     {
1457         opPreMain = "uint f0(void)\n"
1458                     "{\n"
1459                     "  uint i, r = 0;\n"
1460                     "  uint a[14];\n"
1461                     "\n"
1462                     "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1463                     "\n"
1464                     "  " +
1465                     replace(shaderCallInstruction, idTemplate, "0") +
1466                     ";\n"
1467                     "\n"
1468                     "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1469                     "\n"
1470                     "  return r + i;\n"
1471                     "}\n"
1472                     "\n"
1473                     "uint f1(void)\n"
1474                     "{\n"
1475                     "  uint j, t = 0;\n"
1476                     "  uint b[256];\n"
1477                     "\n"
1478                     "  for (j = 0; j < b.length(); j++) b[j] = p.c * j;\n"
1479                     "\n"
1480                     "  v1 = uvec2(0, p.b);\n"
1481                     "\n"
1482                     "  t += f0();\n"
1483                     "\n"
1484                     "  for (j = 0; j < b.length(); j++) t += b[j];\n"
1485                     "\n"
1486                     "  return t + j;\n"
1487                     "}\n";
1488         opInMain = "  v0 = uvec2(0, p.a & id); v3 = uvec2(0, p.d);\n"
1489                    "  result = f1() + v0.y + v1.y + v3.y;\n";
1490 
1491         break;
1492     }
1493 
1494     default:
1495         TCU_THROW(InternalError, "Unknown testType");
1496     }
1497 
1498     if (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)
1499     {
1500         const std::string calleeShader = "#version 460 core\n"
1501                                          "#extension GL_EXT_nonuniform_qualifier : enable\n"
1502                                          "#extension GL_EXT_ray_tracing : require\n"
1503                                          "\n"
1504                                          "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1505                                          "layout(location = 0) callableDataInEXT uvec2 inValue;\n"
1506                                          "\n"
1507                                          "void main()\n"
1508                                          "{\n" +
1509                                          calleeMainPart +
1510                                          "  inValue.y++;\n"
1511                                          "}\n";
1512 
1513         declsPreMain += "layout(location = 0) callableDataEXT uvec2 v0;\n"
1514                         "layout(location = 1) callableDataEXT uvec2 v1;\n"
1515                         "layout(location = 2) callableDataEXT uvec2 v2;\n"
1516                         "layout(location = 3) callableDataEXT uvec2 v3;\n"
1517                         "\n";
1518 
1519         switch (m_data.stage)
1520         {
1521         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1522         {
1523             std::stringstream css;
1524             css << declsPreMain << opPreMain << "\n"
1525                 << "void main()\n"
1526                 << "{\n"
1527                 << declsInMainBeforeOp << opInMain // executeCallableEXT
1528                 << declsInMainAfterOp << "}\n";
1529 
1530             programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1531             programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1532 
1533             break;
1534         }
1535 
1536         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1537         {
1538             programCollection.glslSources.add("rgen")
1539                 << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1540 
1541             std::stringstream css;
1542             css << declsPreMain << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1543                 << "hitAttributeEXT vec3 attribs;\n"
1544                 << "\n"
1545                 << opPreMain << "\n"
1546                 << "void main()\n"
1547                 << "{\n"
1548                 << declsInMainBeforeOp << opInMain // executeCallableEXT
1549                 << declsInMainAfterOp << "}\n";
1550 
1551             programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1552             programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1553 
1554             programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1555             programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1556             programCollection.glslSources.add("sect")
1557                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1558 
1559             break;
1560         }
1561 
1562         case VK_SHADER_STAGE_MISS_BIT_KHR:
1563         {
1564             programCollection.glslSources.add("rgen")
1565                 << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1566 
1567             std::stringstream css;
1568             css << declsPreMain << opPreMain << "\n"
1569                 << "void main()\n"
1570                 << "{\n"
1571                 << declsInMainBeforeOp << opInMain // executeCallableEXT
1572                 << declsInMainAfterOp << "}\n";
1573 
1574             programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1575             programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1576 
1577             programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1578             programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1579             programCollection.glslSources.add("sect")
1580                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1581 
1582             break;
1583         }
1584 
1585         case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1586         {
1587             {
1588                 std::stringstream css;
1589                 css << "#version 460 core\n"
1590                     << "#extension GL_EXT_nonuniform_qualifier : enable\n"
1591                     << "#extension GL_EXT_ray_tracing : require\n"
1592                     << "\n"
1593                     << "layout(location = 4) callableDataEXT float dummy;\n"
1594                     << "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1595                     << "\n"
1596                     << "void main()\n"
1597                     << "{\n"
1598                     << "  executeCallableEXT(1, 4);\n"
1599                     << "}\n";
1600 
1601                 programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1602             }
1603 
1604             {
1605                 std::stringstream css;
1606                 css << declsPreMain << "layout(location = 4) callableDataInEXT float dummyIn;\n"
1607                     << opPreMain << "\n"
1608                     << "void main()\n"
1609                     << "{\n"
1610                     << declsInMainBeforeOp << opInMain // executeCallableEXT
1611                     << declsInMainAfterOp << "}\n";
1612 
1613                 programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1614             }
1615 
1616             programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1617 
1618             break;
1619         }
1620 
1621         default:
1622             TCU_THROW(InternalError, "Unknown stage");
1623         }
1624     }
1625     else if (m_data.testOp == TEST_OP_TRACE_RAY)
1626     {
1627         const std::string missShader = "#version 460 core\n"
1628                                        "#extension GL_EXT_nonuniform_qualifier : enable\n"
1629                                        "#extension GL_EXT_ray_tracing : require\n"
1630                                        "\n"
1631                                        "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1632                                        "layout(location = 0) rayPayloadInEXT uvec2 inValue;\n"
1633                                        "\n"
1634                                        "void main()\n"
1635                                        "{\n" +
1636                                        calleeMainPart +
1637                                        "  inValue.y++;\n"
1638                                        "}\n";
1639 
1640         declsPreMain += "layout(location = 0) rayPayloadEXT uvec2 v0;\n"
1641                         "layout(location = 1) rayPayloadEXT uvec2 v1;\n"
1642                         "layout(location = 2) rayPayloadEXT uvec2 v2;\n"
1643                         "layout(location = 3) rayPayloadEXT uvec2 v3;\n";
1644 
1645         switch (m_data.stage)
1646         {
1647         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1648         {
1649             std::stringstream css;
1650             css << declsPreMain << opPreMain << "\n"
1651                 << "void main()\n"
1652                 << "{\n"
1653                 << declsInMainBeforeOp << opInMain // traceRayEXT
1654                 << declsInMainAfterOp << "}\n";
1655 
1656             programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1657 
1658             programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1659             programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1660             programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1661             programCollection.glslSources.add("sect")
1662                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1663 
1664             programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1665             programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1666             programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1667             programCollection.glslSources.add("sect2")
1668                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1669 
1670             break;
1671         }
1672 
1673         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1674         {
1675             programCollection.glslSources.add("rgen")
1676                 << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1677 
1678             std::stringstream css;
1679             css << declsPreMain << opPreMain << "\n"
1680                 << "void main()\n"
1681                 << "{\n"
1682                 << declsInMainBeforeOp << opInMain // traceRayEXT
1683                 << declsInMainAfterOp << "}\n";
1684 
1685             programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1686 
1687             programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1688             programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1689             programCollection.glslSources.add("sect")
1690                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1691 
1692             programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1693             programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1694             programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1695             programCollection.glslSources.add("sect2")
1696                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1697 
1698             break;
1699         }
1700 
1701         case VK_SHADER_STAGE_MISS_BIT_KHR:
1702         {
1703             programCollection.glslSources.add("rgen")
1704                 << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1705 
1706             std::stringstream css;
1707             css << declsPreMain << opPreMain << "\n"
1708                 << "void main()\n"
1709                 << "{\n"
1710                 << declsInMainBeforeOp << opInMain // traceRayEXT
1711                 << declsInMainAfterOp << "}\n";
1712 
1713             programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1714 
1715             programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1716             programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1717             programCollection.glslSources.add("sect")
1718                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1719 
1720             programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1721             programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1722             programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1723             programCollection.glslSources.add("sect2")
1724                 << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1725 
1726             break;
1727         }
1728 
1729         default:
1730             TCU_THROW(InternalError, "Unknown stage");
1731         }
1732     }
1733     else if (m_data.testOp == TEST_OP_REPORT_INTERSECTION)
1734     {
1735         const std::string anyHitShader = "#version 460 core\n"
1736                                          "#extension GL_EXT_nonuniform_qualifier : enable\n"
1737                                          "#extension GL_EXT_ray_tracing : require\n"
1738                                          "\n"
1739                                          "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1740                                          "hitAttributeEXT block { uvec2 inValue; };\n"
1741                                          "\n"
1742                                          "void main()\n"
1743                                          "{\n" +
1744                                          calleeMainPart + "}\n";
1745 
1746         declsPreMain += "hitAttributeEXT block { uvec2 v0; };\n"
1747                         "uvec2 v1;\n"
1748                         "uvec2 v2;\n"
1749                         "uvec2 v3;\n";
1750 
1751         switch (m_data.stage)
1752         {
1753         case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1754         {
1755             programCollection.glslSources.add("rgen")
1756                 << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1757 
1758             std::stringstream css;
1759             css << declsPreMain << opPreMain << "\n"
1760                 << "void main()\n"
1761                 << "{\n"
1762                 << declsInMainBeforeOp << opInMain // reportIntersectionEXT
1763                 << declsInMainAfterOp << "}\n";
1764 
1765             programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1766             programCollection.glslSources.add("ahit") << glu::AnyHitSource(anyHitShader) << buildOptions;
1767 
1768             programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1769             programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1770 
1771             break;
1772         }
1773 
1774         default:
1775             TCU_THROW(InternalError, "Unknown stage");
1776         }
1777     }
1778     else
1779     {
1780         TCU_THROW(InternalError, "Unknown operation");
1781     }
1782 }
1783 
createInstance(Context & context) const1784 TestInstance *ComplexControlFlowTestCase::createInstance(Context &context) const
1785 {
1786     return new RayTracingComplexControlFlowInstance(context, m_data);
1787 }
1788 
1789 } // namespace
1790 
createComplexControlFlowTests(tcu::TestContext & testCtx)1791 tcu::TestCaseGroup *createComplexControlFlowTests(tcu::TestContext &testCtx)
1792 {
1793     const VkShaderStageFlagBits R = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
1794     const VkShaderStageFlagBits A = VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
1795     const VkShaderStageFlagBits C = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
1796     const VkShaderStageFlagBits M = VK_SHADER_STAGE_MISS_BIT_KHR;
1797     const VkShaderStageFlagBits I = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1798     const VkShaderStageFlagBits L = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1799 
1800     DE_UNREF(A);
1801 
1802     static const struct
1803     {
1804         const char *name;
1805         VkShaderStageFlagBits stage;
1806     } testStages[]{
1807         {"rgen", VK_SHADER_STAGE_RAYGEN_BIT_KHR},  {"chit", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR},
1808         {"ahit", VK_SHADER_STAGE_ANY_HIT_BIT_KHR}, {"sect", VK_SHADER_STAGE_INTERSECTION_BIT_KHR},
1809         {"miss", VK_SHADER_STAGE_MISS_BIT_KHR},    {"call", VK_SHADER_STAGE_CALLABLE_BIT_KHR},
1810     };
1811     static const struct
1812     {
1813         const char *name;
1814         TestOp op;
1815         VkShaderStageFlags applicableInStages;
1816     } testOps[]{
1817         {"execute_callable", TEST_OP_EXECUTE_CALLABLE, R | C | M | L},
1818         {"trace_ray", TEST_OP_TRACE_RAY, R | C | M},
1819         {"report_intersection", TEST_OP_REPORT_INTERSECTION, I},
1820     };
1821     static const struct
1822     {
1823         const char *name;
1824         TestType testType;
1825     } testTypes[]{
1826         {"if", TEST_TYPE_IF},
1827         {"loop", TEST_TYPE_LOOP},
1828         {"switch", TEST_TYPE_SWITCH},
1829         {"loop_double_call", TEST_TYPE_LOOP_DOUBLE_CALL},
1830         {"loop_double_call_sparse", TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE},
1831         {"nested_loop", TEST_TYPE_NESTED_LOOP},
1832         {"nested_loop_loop_before", TEST_TYPE_NESTED_LOOP_BEFORE},
1833         {"nested_loop_loop_after", TEST_TYPE_NESTED_LOOP_AFTER},
1834         {"function_call", TEST_TYPE_FUNCTION_CALL},
1835         {"nested_function_call", TEST_TYPE_NESTED_FUNCTION_CALL},
1836     };
1837 
1838     // Ray tracing complex control flow tests
1839     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "complexcontrolflow"));
1840 
1841     for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
1842     {
1843         const TestType testType = testTypes[testTypeNdx].testType;
1844         de::MovePtr<tcu::TestCaseGroup> testTypeGroup(new tcu::TestCaseGroup(testCtx, testTypes[testTypeNdx].name));
1845 
1846         for (size_t testOpNdx = 0; testOpNdx < DE_LENGTH_OF_ARRAY(testOps); ++testOpNdx)
1847         {
1848             const TestOp testOp = testOps[testOpNdx].op;
1849             de::MovePtr<tcu::TestCaseGroup> testOpGroup(new tcu::TestCaseGroup(testCtx, testOps[testOpNdx].name));
1850 
1851             for (size_t testStagesNdx = 0; testStagesNdx < DE_LENGTH_OF_ARRAY(testStages); ++testStagesNdx)
1852             {
1853                 const VkShaderStageFlagBits testStage = testStages[testStagesNdx].stage;
1854                 const std::string testName            = de::toString(testStages[testStagesNdx].name);
1855                 const uint32_t width                  = 4u;
1856                 const uint32_t height                 = 4u;
1857                 const CaseDef caseDef                 = {
1858                     testType,  //  TestType testType;
1859                     testOp,    //  TestOp testOp;
1860                     testStage, //  VkShaderStageFlagBits stage;
1861                     width,     //  uint32_t width;
1862                     height,    //  uint32_t height;
1863                 };
1864 
1865                 if ((testOps[testOpNdx].applicableInStages & static_cast<VkShaderStageFlags>(testStage)) == 0)
1866                     continue;
1867 
1868                 testOpGroup->addChild(new ComplexControlFlowTestCase(testCtx, testName.c_str(), caseDef));
1869             }
1870 
1871             testTypeGroup->addChild(testOpGroup.release());
1872         }
1873 
1874         group->addChild(testTypeGroup.release());
1875     }
1876 
1877     return group.release();
1878 }
1879 
1880 } // namespace RayTracing
1881 } // namespace vkt
1882