1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Testing traversal control in ray tracing shaders
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayTracingTraversalControlTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vktTestGroupUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkObjUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkBufferWithMemory.hpp"
35 #include "vkImageWithMemory.hpp"
36 #include "vkTypeUtil.hpp"
37 #include "vkImageUtil.hpp"
38 #include "deRandom.hpp"
39 #include "tcuTexture.hpp"
40 #include "tcuTextureUtil.hpp"
41 #include "tcuTestLog.hpp"
42 #include "tcuImageCompare.hpp"
43 
44 #include "vkRayTracingUtil.hpp"
45 
46 namespace vkt
47 {
48 namespace RayTracing
49 {
50 namespace
51 {
52 using namespace vk;
53 using namespace vkt;
54 
55 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
56                                               VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
57                                               VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
58 
// Traversal-control scenario exercised by a test case.
//
// The numeric values are significant: they are used as row indices into the
// shader name table in SingleSquareConfiguration::initRayTracingShaders().
enum HitShaderTestType
{
    // Intersection shader calls reportIntersectionEXT().
    HSTT_ISECT_REPORT_INTERSECTION = 0,
    // Intersection shader has an empty body and reports nothing.
    HSTT_ISECT_DONT_REPORT_INTERSECTION = 1,
    // Any-hit shader has an empty body.
    HSTT_AHIT_PASS_THROUGH = 2,
    // Any-hit shader executes ignoreIntersectionEXT.
    HSTT_AHIT_IGNORE_INTERSECTION = 3,
    // Any-hit shader executes terminateRayEXT.
    HSTT_AHIT_TERMINATE_RAY = 4,
    // Number of scenarios; not a valid scenario itself.
    HSTT_COUNT
};
68 
// Kind of geometry placed in the bottom level acceleration structure.
enum BottomTestType
{
    BTT_TRIANGLES, // triangle geometry (no intersection shader is added to the pipeline)
    BTT_AABBS      // axis-aligned bounding boxes (an intersection shader is added to the pipeline)
};
74 
// Default image dimensions in pixels.
// NOTE(review): not referenced in this part of the file; presumably used where
// the test cases are created - confirm.
constexpr uint32_t TEST_WIDTH  = 8u;
constexpr uint32_t TEST_HEIGHT = 8u;

// Defined after TestConfiguration, whose interface takes it by reference.
struct TestParams;
79 
80 class TestConfiguration
81 {
82 public:
~TestConfiguration()83     virtual ~TestConfiguration()
84     {
85     }
86 
87     virtual std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures(
88         Context &context, TestParams &testParams) = 0;
89     virtual de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure(
90         Context &context, TestParams &testParams,
91         std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures)    = 0;
92     virtual void initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
93                                        TestParams &testParams)                                              = 0;
94     virtual void initShaderBindingTables(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
95                                          TestParams &testParams, VkPipeline pipeline, uint32_t shaderGroupHandleSize,
96                                          uint32_t shaderGroupBaseAlignment,
97                                          de::MovePtr<BufferWithMemory> &raygenShaderBindingTable,
98                                          de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
99                                          de::MovePtr<BufferWithMemory> &missShaderBindingTable,
100                                          de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
101                                          VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
102                                          VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
103                                          VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
104                                          VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion) = 0;
105     virtual bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)      = 0;
106     virtual VkFormat getResultImageFormat()                                                                 = 0;
107     virtual size_t getResultImageFormatSize()                                                               = 0;
108     virtual VkClearValue getClearValue()                                                                    = 0;
109 };
110 
111 struct TestParams
112 {
113     uint32_t width;
114     uint32_t height;
115     HitShaderTestType hitShaderTestType;
116     BottomTestType bottomType;
117     de::SharedPtr<TestConfiguration> testConfiguration;
118 };
119 
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)120 uint32_t getShaderGroupHandleSize(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
121 {
122     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
123 
124     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
125     return rayTracingPropertiesKHR->getShaderGroupHandleSize();
126 }
127 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)128 uint32_t getShaderGroupBaseAlignment(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
129 {
130     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
131 
132     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
133     return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
134 }
135 
makeImageCreateInfo(uint32_t width,uint32_t height,uint32_t depth,VkFormat format)136 VkImageCreateInfo makeImageCreateInfo(uint32_t width, uint32_t height, uint32_t depth, VkFormat format)
137 {
138     const VkImageCreateInfo imageCreateInfo = {
139         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
140         DE_NULL,                             // const void* pNext;
141         (VkImageCreateFlags)0u,              // VkImageCreateFlags flags;
142         VK_IMAGE_TYPE_3D,                    // VkImageType imageType;
143         format,                              // VkFormat format;
144         makeExtent3D(width, height, depth),  // VkExtent3D extent;
145         1u,                                  // uint32_t mipLevels;
146         1u,                                  // uint32_t arrayLayers;
147         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
148         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
149         VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
150             VK_IMAGE_USAGE_TRANSFER_DST_BIT, // VkImageUsageFlags usage;
151         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
152         0u,                                  // uint32_t queueFamilyIndexCount;
153         DE_NULL,                             // const uint32_t* pQueueFamilyIndices;
154         VK_IMAGE_LAYOUT_UNDEFINED            // VkImageLayout initialLayout;
155     };
156 
157     return imageCreateInfo;
158 }
159 
160 class SingleSquareConfiguration : public TestConfiguration
161 {
162 public:
163     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures(
164         Context &context, TestParams &testParams) override;
165     de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure(
166         Context &context, TestParams &testParams,
167         std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures) override;
168     void initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
169                                TestParams &testParams) override;
170     void initShaderBindingTables(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
171                                  TestParams &testParams, VkPipeline pipeline, uint32_t shaderGroupHandleSize,
172                                  uint32_t shaderGroupBaseAlignment,
173                                  de::MovePtr<BufferWithMemory> &raygenShaderBindingTable,
174                                  de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
175                                  de::MovePtr<BufferWithMemory> &missShaderBindingTable,
176                                  de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
177                                  VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
178                                  VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
179                                  VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
180                                  VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion) override;
181     bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) override;
182     VkFormat getResultImageFormat() override;
183     size_t getResultImageFormatSize() override;
184     VkClearValue getClearValue() override;
185 };
186 
187 std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> SingleSquareConfiguration::
initBottomAccelerationStructures(Context & context,TestParams & testParams)188     initBottomAccelerationStructures(Context &context, TestParams &testParams)
189 {
190     DE_UNREF(context);
191 
192     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> result;
193     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
194         makeBottomLevelAccelerationStructure();
195     bottomLevelAccelerationStructure->setGeometryCount(1);
196 
197     de::SharedPtr<RaytracedGeometryBase> geometry;
198     if (testParams.bottomType == BTT_TRIANGLES)
199     {
200         tcu::Vec3 v0(1.0f, float(testParams.height) - 1.0f, 0.0f);
201         tcu::Vec3 v1(1.0f, 1.0f, 0.0f);
202         tcu::Vec3 v2(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0f);
203         tcu::Vec3 v3(float(testParams.width) - 1.0f, 1.0f, 0.0f);
204 
205         geometry =
206             makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
207         geometry->addVertex(v0);
208         geometry->addVertex(v1);
209         geometry->addVertex(v2);
210         geometry->addVertex(v2);
211         geometry->addVertex(v1);
212         geometry->addVertex(v3);
213     }
214     else // testParams.bottomType != BTT_TRIANGLES
215     {
216         tcu::Vec3 v0(1.0f, 1.0f, -0.1f);
217         tcu::Vec3 v1(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.1f);
218 
219         geometry =
220             makeRaytracedGeometry(VK_GEOMETRY_TYPE_AABBS_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
221         geometry->addVertex(v0);
222         geometry->addVertex(v1);
223     }
224     bottomLevelAccelerationStructure->addGeometry(geometry);
225 
226     result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
227 
228     return result;
229 }
230 
initTopAccelerationStructure(Context & context,TestParams & testParams,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)231 de::MovePtr<TopLevelAccelerationStructure> SingleSquareConfiguration::initTopAccelerationStructure(
232     Context &context, TestParams &testParams,
233     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures)
234 {
235     DE_UNREF(context);
236     DE_UNREF(testParams);
237 
238     de::MovePtr<TopLevelAccelerationStructure> result = makeTopLevelAccelerationStructure();
239     result->setInstanceCount(1);
240     result->addInstance(bottomLevelAccelerationStructures[0]);
241 
242     return result;
243 }
244 
initRayTracingShaders(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams)245 void SingleSquareConfiguration::initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
246                                                       Context &context, TestParams &testParams)
247 {
248     const DeviceInterface &vkd = context.getDeviceInterface();
249     const VkDevice device      = context.getDevice();
250 
251     const std::vector<std::vector<std::string>> shaderNames = {
252         {"rgen", "isect_report", "ahit", "chit", "miss"},
253         {"rgen", "isect_pass_through", "ahit", "chit", "miss"},
254         {"rgen", "isect_report", "ahit_pass_through", "chit", "miss"},
255         {"rgen", "isect_report", "ahit_ignore", "chit", "miss"},
256         {"rgen", "isect_report", "ahit_terminate", "chit", "miss"},
257     };
258     rayTracingPipeline->addShader(
259         VK_SHADER_STAGE_RAYGEN_BIT_KHR,
260         createShaderModule(vkd, device, context.getBinaryCollection().get(shaderNames[testParams.hitShaderTestType][0]),
261                            0),
262         0);
263     if (testParams.bottomType == BTT_AABBS)
264         rayTracingPipeline->addShader(
265             VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
266             createShaderModule(vkd, device,
267                                context.getBinaryCollection().get(shaderNames[testParams.hitShaderTestType][1]), 0),
268             1);
269     rayTracingPipeline->addShader(
270         VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
271         createShaderModule(vkd, device, context.getBinaryCollection().get(shaderNames[testParams.hitShaderTestType][2]),
272                            0),
273         1);
274     rayTracingPipeline->addShader(
275         VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
276         createShaderModule(vkd, device, context.getBinaryCollection().get(shaderNames[testParams.hitShaderTestType][3]),
277                            0),
278         1);
279     rayTracingPipeline->addShader(
280         VK_SHADER_STAGE_MISS_BIT_KHR,
281         createShaderModule(vkd, device, context.getBinaryCollection().get(shaderNames[testParams.hitShaderTestType][4]),
282                            0),
283         2);
284 }
285 
initShaderBindingTables(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams,VkPipeline pipeline,uint32_t shaderGroupHandleSize,uint32_t shaderGroupBaseAlignment,de::MovePtr<BufferWithMemory> & raygenShaderBindingTable,de::MovePtr<BufferWithMemory> & hitShaderBindingTable,de::MovePtr<BufferWithMemory> & missShaderBindingTable,de::MovePtr<BufferWithMemory> & callableShaderBindingTable,VkStridedDeviceAddressRegionKHR & raygenShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & hitShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & missShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & callableShaderBindingTableRegion)286 void SingleSquareConfiguration::initShaderBindingTables(
287     de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context, TestParams &testParams, VkPipeline pipeline,
288     uint32_t shaderGroupHandleSize, uint32_t shaderGroupBaseAlignment,
289     de::MovePtr<BufferWithMemory> &raygenShaderBindingTable, de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
290     de::MovePtr<BufferWithMemory> &missShaderBindingTable, de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
291     VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
292     VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
293     VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
294     VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion)
295 {
296     DE_UNREF(testParams);
297     DE_UNREF(callableShaderBindingTable);
298 
299     const DeviceInterface &vkd = context.getDeviceInterface();
300     const VkDevice device      = context.getDevice();
301     Allocator &allocator       = context.getDefaultAllocator();
302 
303     raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
304         vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
305     hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
306         vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
307     missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
308         vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
309 
310     raygenShaderBindingTableRegion =
311         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
312                                           shaderGroupHandleSize, shaderGroupHandleSize);
313     hitShaderBindingTableRegion =
314         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
315                                           shaderGroupHandleSize, shaderGroupHandleSize);
316     missShaderBindingTableRegion =
317         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
318                                           shaderGroupHandleSize, shaderGroupHandleSize);
319     callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
320 }
321 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)322 bool SingleSquareConfiguration::verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)
323 {
324     // create result image
325     tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
326     tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 2,
327                                              resultBuffer->getAllocation().getHostPtr());
328 
329     // create reference image
330     std::vector<uint32_t> reference(testParams.width * testParams.height * 2);
331     tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 2, reference.data());
332 
333     // clear reference image with hit and miss values
334     // Reference image has two layers:
335     //   - ahit shader writes results to layer 0
336     //   - chit shader writes results to layer 1
337     //   - miss shader writes results to layer 0
338     //   - rays that missed on layer 0 - should have value 0 on layer 1
339     tcu::UVec4 missValue0 = tcu::UVec4(4, 0, 0, 0);
340     tcu::UVec4 missValue1 = tcu::UVec4(0, 0, 0, 0);
341     tcu::UVec4 hitValue0, hitValue1;
342     switch (testParams.hitShaderTestType)
343     {
344     case HSTT_ISECT_REPORT_INTERSECTION:
345         hitValue0 = tcu::UVec4(1, 0, 0, 0); // ahit returns 1
346         hitValue1 = tcu::UVec4(3, 0, 0, 0); // chit returns 3
347         break;
348     case HSTT_ISECT_DONT_REPORT_INTERSECTION:
349         hitValue0 = missValue0; // no ahit - results should report miss value
350         hitValue1 = missValue1; // no chit - results should report miss value
351         break;
352     case HSTT_AHIT_PASS_THROUGH:
353         hitValue0 = tcu::UVec4(0, 0, 0, 0); // empty ahit shader. Initial value from rgen written to result
354         hitValue1 = tcu::UVec4(3, 0, 0, 0); // chit returns 3
355         break;
356     case HSTT_AHIT_IGNORE_INTERSECTION:
357         hitValue0 = missValue0; // ahit ignores intersection - results should report miss value
358         hitValue1 = missValue1; // no chit - results should report miss value
359         break;
360     case HSTT_AHIT_TERMINATE_RAY:
361         hitValue0 = tcu::UVec4(
362             1, 0, 0, 0); // ahit should return 1. If it returned 2, then terminateRayEXT did not terminate ahit shader
363         hitValue1 = tcu::UVec4(3, 0, 0, 0); // chit returns 3
364         break;
365     default:
366         TCU_THROW(InternalError, "Wrong shader test type");
367     }
368 
369     tcu::clear(referenceAccess, missValue0);
370     for (uint32_t y = 0; y < testParams.width; ++y)
371         for (uint32_t x = 0; x < testParams.height; ++x)
372             referenceAccess.setPixel(missValue1, x, y, 1);
373 
374     for (uint32_t y = 1; y < testParams.width - 1; ++y)
375         for (uint32_t x = 1; x < testParams.height - 1; ++x)
376         {
377             referenceAccess.setPixel(hitValue0, x, y, 0);
378             referenceAccess.setPixel(hitValue1, x, y, 1);
379         }
380 
381     // compare result and reference
382     return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess,
383                                     resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
384 }
385 
getResultImageFormat()386 VkFormat SingleSquareConfiguration::getResultImageFormat()
387 {
388     return VK_FORMAT_R32_UINT;
389 }
390 
getResultImageFormatSize()391 size_t SingleSquareConfiguration::getResultImageFormatSize()
392 {
393     return sizeof(uint32_t);
394 }
395 
getClearValue()396 VkClearValue SingleSquareConfiguration::getClearValue()
397 {
398     return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
399 }
400 
401 class TraversalControlTestCase : public TestCase
402 {
403 public:
404     TraversalControlTestCase(tcu::TestContext &context, const char *name, const TestParams data);
405     ~TraversalControlTestCase(void);
406 
407     virtual void checkSupport(Context &context) const;
408     virtual void initPrograms(SourceCollections &programCollection) const;
409     virtual TestInstance *createInstance(Context &context) const;
410 
411 private:
412     TestParams m_data;
413 };
414 
415 class TraversalControlTestInstance : public TestInstance
416 {
417 public:
418     TraversalControlTestInstance(Context &context, const TestParams &data);
419     ~TraversalControlTestInstance(void);
420     tcu::TestStatus iterate(void);
421 
422 protected:
423     de::MovePtr<BufferWithMemory> runTest();
424 
425 private:
426     TestParams m_data;
427 };
428 
TraversalControlTestCase(tcu::TestContext & context,const char * name,const TestParams data)429 TraversalControlTestCase::TraversalControlTestCase(tcu::TestContext &context, const char *name, const TestParams data)
430     : vkt::TestCase(context, name)
431     , m_data(data)
432 {
433 }
434 
~TraversalControlTestCase(void)435 TraversalControlTestCase::~TraversalControlTestCase(void)
436 {
437 }
438 
checkSupport(Context & context) const439 void TraversalControlTestCase::checkSupport(Context &context) const
440 {
441     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
442     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
443 
444     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
445         context.getRayTracingPipelineFeatures();
446     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
447         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
448 
449     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
450         context.getAccelerationStructureFeatures();
451     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
452         TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires "
453                              "VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
454 }
455 
initPrograms(SourceCollections & programCollection) const456 void TraversalControlTestCase::initPrograms(SourceCollections &programCollection) const
457 {
458     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
459     {
460         std::stringstream css;
461         css << "#version 460 core\n"
462                "#extension GL_EXT_ray_tracing : require\n"
463                "layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
464                "layout(r32ui, set = 0, binding = 0) uniform uimage3D result;\n"
465                "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
466                "\n"
467                "void main()\n"
468                "{\n"
469                "  float tmin     = 0.0;\n"
470                "  float tmax     = 1.0;\n"
471                "  vec3  origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5f, float(gl_LaunchIDEXT.y) + 0.5f, 0.5f);\n"
472                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
473                "  hitValue       = uvec4(0,0,0,0);\n"
474                "  traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
475                "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 0), uvec4(hitValue.x, 0, 0, 0));\n"
476                "  imageStore(result, ivec3(gl_LaunchIDEXT.xy, 1), uvec4(hitValue.y, 0, 0, 0));\n"
477                "}\n";
478         programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
479     }
480 
481     {
482         std::stringstream css;
483         css << "#version 460 core\n"
484                "#extension GL_EXT_ray_tracing : require\n"
485                "hitAttributeEXT uvec4 hitAttribute;\n"
486                "void main()\n"
487                "{\n"
488                "  hitAttribute = uvec4(0,0,0,0);\n"
489                "  reportIntersectionEXT(0.5f, 0);\n"
490                "}\n";
491 
492         programCollection.glslSources.add("isect_report")
493             << glu::IntersectionSource(updateRayTracingGLSL(css.str())) << buildOptions;
494     }
495 
496     {
497         std::stringstream css;
498         css << "#version 460 core\n"
499                "#extension GL_EXT_ray_tracing : require\n"
500                "void main()\n"
501                "{\n"
502                "}\n";
503 
504         programCollection.glslSources.add("isect_pass_through")
505             << glu::IntersectionSource(updateRayTracingGLSL(css.str())) << buildOptions;
506     }
507 
508     {
509         std::stringstream css;
510         css << "#version 460 core\n"
511                "#extension GL_EXT_ray_tracing : require\n"
512                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
513                "void main()\n"
514                "{\n"
515                "  hitValue.x = 1;\n"
516                "}\n";
517 
518         programCollection.glslSources.add("ahit") << glu::AnyHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
519     }
520 
521     {
522         std::stringstream css;
523         css << "#version 460 core\n"
524                "#extension GL_EXT_ray_tracing : require\n"
525                "void main()\n"
526                "{\n"
527                "}\n";
528 
529         programCollection.glslSources.add("ahit_pass_through")
530             << glu::AnyHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
531     }
532 
533     {
534         std::stringstream css;
535         css << "#version 460 core\n"
536                "#extension GL_EXT_ray_tracing : require\n"
537                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
538                "void main()\n"
539                "{\n"
540                "  hitValue.x = 1;\n"
541                "  ignoreIntersectionEXT;\n"
542                "  hitValue.x = 2;\n"
543                "}\n";
544 
545         programCollection.glslSources.add("ahit_ignore")
546             << glu::AnyHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
547     }
548 
549     {
550         std::stringstream css;
551         css << "#version 460 core\n"
552                "#extension GL_EXT_ray_tracing : require\n"
553                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
554                "void main()\n"
555                "{\n"
556                "  hitValue.x = 1;\n"
557                "  terminateRayEXT;\n"
558                "  hitValue.x = 2;\n"
559                "}\n";
560 
561         programCollection.glslSources.add("ahit_terminate")
562             << glu::AnyHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
563     }
564 
565     {
566         std::stringstream css;
567         css << "#version 460 core\n"
568                "#extension GL_EXT_ray_tracing : require\n"
569                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
570                "void main()\n"
571                "{\n"
572                "  hitValue.y = 3;\n"
573                "}\n";
574 
575         programCollection.glslSources.add("chit")
576             << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
577     }
578 
579     {
580         std::stringstream css;
581         css << "#version 460 core\n"
582                "#extension GL_EXT_ray_tracing : require\n"
583                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
584                "void main()\n"
585                "{\n"
586                "  hitValue.x = 4;\n"
587                "}\n";
588 
589         programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
590     }
591 }
592 
createInstance(Context & context) const593 TestInstance *TraversalControlTestCase::createInstance(Context &context) const
594 {
595     return new TraversalControlTestInstance(context, m_data);
596 }
597 
TraversalControlTestInstance(Context & context,const TestParams & data)598 TraversalControlTestInstance::TraversalControlTestInstance(Context &context, const TestParams &data)
599     : vkt::TestInstance(context)
600     , m_data(data)
601 {
602 }
603 
~TraversalControlTestInstance(void)604 TraversalControlTestInstance::~TraversalControlTestInstance(void)
605 {
606 }
607 
runTest()608 de::MovePtr<BufferWithMemory> TraversalControlTestInstance::runTest()
609 {
610     const InstanceInterface &vki          = m_context.getInstanceInterface();
611     const DeviceInterface &vkd            = m_context.getDeviceInterface();
612     const VkDevice device                 = m_context.getDevice();
613     const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
614     const uint32_t queueFamilyIndex       = m_context.getUniversalQueueFamilyIndex();
615     const VkQueue queue                   = m_context.getUniversalQueue();
616     Allocator &allocator                  = m_context.getDefaultAllocator();
617 
618     const Move<VkDescriptorSetLayout> descriptorSetLayout =
619         DescriptorSetLayoutBuilder()
620             .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
621             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
622             .build(vkd, device);
623     const Move<VkDescriptorPool> descriptorPool =
624         DescriptorPoolBuilder()
625             .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
626             .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
627             .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
628     const Move<VkDescriptorSet> descriptorSet   = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
629     const Move<VkPipelineLayout> pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
630 
631     de::MovePtr<RayTracingPipeline> rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
632     m_data.testConfiguration->initRayTracingShaders(rayTracingPipeline, m_context, m_data);
633     Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);
634 
635     de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
636     de::MovePtr<BufferWithMemory> hitShaderBindingTable;
637     de::MovePtr<BufferWithMemory> missShaderBindingTable;
638     de::MovePtr<BufferWithMemory> callableShaderBindingTable;
639     VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion;
640     VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion;
641     VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion;
642     VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion;
643     m_data.testConfiguration->initShaderBindingTables(
644         rayTracingPipeline, m_context, m_data, *pipeline, getShaderGroupHandleSize(vki, physicalDevice),
645         getShaderGroupBaseAlignment(vki, physicalDevice), raygenShaderBindingTable, hitShaderBindingTable,
646         missShaderBindingTable, callableShaderBindingTable, raygenShaderBindingTableRegion, hitShaderBindingTableRegion,
647         missShaderBindingTableRegion, callableShaderBindingTableRegion);
648 
649     const VkFormat imageFormat              = m_data.testConfiguration->getResultImageFormat();
650     const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_data.width, m_data.height, 2, imageFormat);
651     const VkImageSubresourceRange imageSubresourceRange =
652         makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
653     const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(
654         new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
655     const Move<VkImageView> imageView =
656         makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, imageFormat, imageSubresourceRange);
657 
658     const VkBufferCreateInfo resultBufferCreateInfo =
659         makeBufferCreateInfo(m_data.width * m_data.height * 2 * m_data.testConfiguration->getResultImageFormatSize(),
660                              VK_BUFFER_USAGE_TRANSFER_DST_BIT);
661     const VkImageSubresourceLayers resultBufferImageSubresourceLayers =
662         makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
663     const VkBufferImageCopy resultBufferImageRegion =
664         makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 2), resultBufferImageSubresourceLayers);
665     de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(
666         new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
667 
668     const VkDescriptorImageInfo descriptorImageInfo =
669         makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
670 
671     const Move<VkCommandPool> cmdPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
672     const Move<VkCommandBuffer> cmdBuffer =
673         allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
674 
675     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelAccelerationStructures;
676     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
677 
678     beginCommandBuffer(vkd, *cmdBuffer, 0u);
679     {
680         const VkImageMemoryBarrier preImageBarrier =
681             makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
682                                    VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, **image, imageSubresourceRange);
683         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
684                                       VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
685 
686         const VkClearValue clearValue = m_data.testConfiguration->getClearValue();
687         vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1,
688                                &imageSubresourceRange);
689 
690         const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(
691             VK_ACCESS_TRANSFER_WRITE_BIT,
692             VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
693             VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, **image, imageSubresourceRange);
694         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
695                                       VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
696 
697         bottomLevelAccelerationStructures =
698             m_data.testConfiguration->initBottomAccelerationStructures(m_context, m_data);
699         for (auto &blas : bottomLevelAccelerationStructures)
700             blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
701         topLevelAccelerationStructure = m_data.testConfiguration->initTopAccelerationStructure(
702             m_context, m_data, bottomLevelAccelerationStructures);
703         topLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);
704 
705         const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = topLevelAccelerationStructure.get();
706         VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
707             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
708             DE_NULL,                                                           //  const void* pNext;
709             1u,                                                                //  uint32_t accelerationStructureCount;
710             topLevelAccelerationStructurePtr->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
711         };
712 
713         DescriptorSetUpdateBuilder()
714             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
715                          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
716             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
717                          VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
718             .update(vkd, device);
719 
720         vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1,
721                                   &descriptorSet.get(), 0, DE_NULL);
722 
723         vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
724 
725         cmdTraceRays(vkd, *cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
726                      &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, m_data.width, m_data.height, 1);
727 
728         const VkMemoryBarrier postTraceMemoryBarrier =
729             makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
730         const VkMemoryBarrier postCopyMemoryBarrier =
731             makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
732         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR,
733                                  VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
734 
735         vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u,
736                                  &resultBufferImageRegion);
737 
738         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
739                                  &postCopyMemoryBarrier);
740     }
741     endCommandBuffer(vkd, *cmdBuffer);
742 
743     submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
744 
745     invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(),
746                                 resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
747 
748     return resultBuffer;
749 }
750 
iterate(void)751 tcu::TestStatus TraversalControlTestInstance::iterate(void)
752 {
753     // run test using arrays of pointers
754     const de::MovePtr<BufferWithMemory> buffer = runTest();
755 
756     if (!m_data.testConfiguration->verifyImage(buffer.get(), m_context, m_data))
757         return tcu::TestStatus::fail("Fail");
758     return tcu::TestStatus::pass("Pass");
759 }
760 
761 } // namespace
762 
createTraversalControlTests(tcu::TestContext & testCtx)763 tcu::TestCaseGroup *createTraversalControlTests(tcu::TestContext &testCtx)
764 {
765     // Tests verifying traversal control in RT hit shaders
766     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "traversal_control"));
767 
768     struct HitShaderTestTypeData
769     {
770         HitShaderTestType shaderTestType;
771         bool onlyAabbTest;
772         const char *name;
773     } hitShaderTestTypes[] = {
774         {HSTT_ISECT_REPORT_INTERSECTION, true, "isect_report_intersection"},
775         {HSTT_ISECT_DONT_REPORT_INTERSECTION, true, "isect_dont_report_intersection"},
776         {HSTT_AHIT_PASS_THROUGH, false, "ahit_pass_through"},
777         {HSTT_AHIT_IGNORE_INTERSECTION, false, "ahit_ignore_intersection"},
778         {HSTT_AHIT_TERMINATE_RAY, false, "ahit_terminate_ray"},
779     };
780 
781     struct
782     {
783         BottomTestType testType;
784         const char *name;
785     } bottomTestTypes[] = {
786         {BTT_TRIANGLES, "triangles"},
787         {BTT_AABBS, "aabbs"},
788     };
789 
790     for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(hitShaderTestTypes); ++shaderTestNdx)
791     {
792         de::MovePtr<tcu::TestCaseGroup> testTypeGroup(
793             new tcu::TestCaseGroup(group->getTestContext(), hitShaderTestTypes[shaderTestNdx].name));
794 
795         for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(bottomTestTypes); ++testTypeNdx)
796         {
797             if (hitShaderTestTypes[shaderTestNdx].onlyAabbTest && bottomTestTypes[testTypeNdx].testType != BTT_AABBS)
798                 continue;
799 
800             TestParams testParams{TEST_WIDTH, TEST_HEIGHT, hitShaderTestTypes[shaderTestNdx].shaderTestType,
801                                   bottomTestTypes[testTypeNdx].testType,
802                                   de::SharedPtr<TestConfiguration>(new SingleSquareConfiguration())};
803             testTypeGroup->addChild(
804                 new TraversalControlTestCase(group->getTestContext(), bottomTestTypes[testTypeNdx].name, testParams));
805         }
806         group->addChild(testTypeGroup.release());
807     }
808 
809     return group.release();
810 }
811 
812 } // namespace RayTracing
813 
814 } // namespace vkt
815