1 /*-------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2022 The Khronos Group Inc.
6  * Copyright (c) 2022 NVIDIA Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Ray Tracing Position Fetch Tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktRayTracingPositionFetchTests.hpp"
26 #include "vktTestCase.hpp"
27 
28 #include "vkRayTracingUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkBufferWithMemory.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35 
36 #include "deUniquePtr.hpp"
37 
38 #include "tcuVectorUtil.hpp"
39 
40 #include <sstream>
41 #include <vector>
42 #include <iostream>
43 
44 namespace vkt
45 {
46 namespace RayTracing
47 {
48 
49 namespace
50 {
51 
52 using namespace vk;
53 
// Bit flags selecting optional test behaviour. One test variant is generated for every mask value
// strictly below TEST_FLAG_BIT_LAST.
enum TestFlagBits
{
    TEST_FLAG_BIT_INSTANCE_TRANSFORM = 0x1u, // Instance the BLAS with a slightly scaled (non-identity) transform.
    TEST_FLAG_BIT_LAST               = 0x2u  // One past the highest flag bit; exclusive bound when enumerating masks.
};
59 
// Test-name fragment for each TestFlagBits bit, indexed by bit position. Used to build test case names;
// it needs one entry per flag bit below TEST_FLAG_BIT_LAST. Read-only lookup table, hence const.
const std::vector<std::string> testFlagBitNames = {
    "instance_transform",
};
63 
64 struct TestParams
65 {
66     vk::VkAccelerationStructureBuildTypeKHR buildType; // are we making AS on CPU or GPU
67     VkFormat vertexFormat;
68     uint32_t testFlagMask;
69 };
70 
71 class PositionFetchCase : public TestCase
72 {
73 public:
74     PositionFetchCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params);
~PositionFetchCase(void)75     virtual ~PositionFetchCase(void)
76     {
77     }
78 
79     virtual void checkSupport(Context &context) const;
80     virtual void initPrograms(vk::SourceCollections &programCollection) const;
81     virtual TestInstance *createInstance(Context &context) const;
82 
83 protected:
84     TestParams m_params;
85 };
86 
87 class PositionFetchInstance : public TestInstance
88 {
89 public:
90     PositionFetchInstance(Context &context, const TestParams &params);
~PositionFetchInstance(void)91     virtual ~PositionFetchInstance(void)
92     {
93     }
94 
95     virtual tcu::TestStatus iterate(void);
96 
97 protected:
98     TestParams m_params;
99 };
100 
PositionFetchCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)101 PositionFetchCase::PositionFetchCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params)
102     : TestCase(testCtx, name)
103     , m_params(params)
104 {
105 }
106 
checkSupport(Context & context) const107 void PositionFetchCase::checkSupport(Context &context) const
108 {
109     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
110     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
111     context.requireDeviceFunctionality("VK_KHR_ray_tracing_position_fetch");
112 
113     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
114         context.getAccelerationStructureFeatures();
115     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
116         TCU_THROW(TestError,
117                   "VK_KHR_ray_query requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
118 
119     if (m_params.buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR &&
120         accelerationStructureFeaturesKHR.accelerationStructureHostCommands == false)
121         TCU_THROW(NotSupportedError,
122                   "Requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructureHostCommands");
123 
124     const VkPhysicalDeviceRayTracingPositionFetchFeaturesKHR &rayTracingPositionFetchFeaturesKHR =
125         context.getRayTracingPositionFetchFeatures();
126     if (rayTracingPositionFetchFeaturesKHR.rayTracingPositionFetch == false)
127         TCU_THROW(NotSupportedError, "Requires VkPhysicalDevicePositionFetchFeaturesKHR.rayTracingPositionFetch");
128 
129     // Check supported vertex format.
130     checkAccelerationStructureVertexBufferFormat(context.getInstanceInterface(), context.getPhysicalDevice(),
131                                                  m_params.vertexFormat);
132 }
133 
initPrograms(vk::SourceCollections & programCollection) const134 void PositionFetchCase::initPrograms(vk::SourceCollections &programCollection) const
135 {
136     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
137 
138     uint32_t numRays = 1; // XXX
139 
140     std::ostringstream layoutDecls;
141     layoutDecls << "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
142                 << "layout(set=0, binding=1, std430) buffer RayOrigins {\n"
143                 << "  vec4 values[" << numRays << "];\n"
144                 << "} origins;\n"
145                 << "layout(set=0, binding=2, std430) buffer OutputPositions {\n"
146                 << "  vec4 values[" << 6 * numRays << "];\n"
147                 << "} modes;\n";
148     const auto layoutDeclsStr = layoutDecls.str();
149 
150     std::ostringstream rgen;
151     rgen << "#version 460 core\n"
152          << "#extension GL_EXT_ray_tracing : require\n"
153          << "#extension GL_EXT_ray_tracing_position_fetch : require\n"
154          << "\n"
155          << "layout(location=0) rayPayloadEXT int value;\n"
156          << "\n"
157          << layoutDeclsStr << "\n"
158          << "void main()\n"
159          << "{\n"
160          << "  const uint  cullMask  = 0xFF;\n"
161          << "  const vec3  origin    = origins.values[gl_LaunchIDEXT.x].xyz;\n"
162          << "  const vec3  direction = vec3(0.0, 0.0, -1.0);\n"
163          << "  const float tMin      = 0.0;\n"
164          << "  const float tMax      = 2.0;\n"
165          << "  value                 = 0xFFFFFFFF;\n"
166          << "  traceRayEXT(topLevelAS, gl_RayFlagsNoneEXT, cullMask, 0, 0, 0, origin, tMin, direction, tMax, 0);\n"
167          << "}\n";
168 
169     std::ostringstream ah;
170     ah << "#version 460 core\n"
171        << "#extension GL_EXT_ray_tracing : require\n"
172        << "#extension GL_EXT_ray_tracing_position_fetch : require\n"
173        << "\n"
174        << layoutDeclsStr << "\n"
175        << "layout(location=0) rayPayloadEXT int value;\n"
176        << "\n"
177        << "void main()\n"
178        << "{\n"
179        << "  for (int i=0; i<3; i++) {\n"
180        << "    modes.values[6*gl_LaunchIDEXT.x+2*i] = vec4(gl_HitTriangleVertexPositionsEXT[i], 0.0);\n"
181        << "  }\n"
182        << "  terminateRayEXT;\n"
183        << "}\n";
184 
185     std::ostringstream ch;
186     ch << "#version 460 core\n"
187        << "#extension GL_EXT_ray_tracing : require\n"
188        << "#extension GL_EXT_ray_tracing_position_fetch : require\n"
189        << "\n"
190        << layoutDeclsStr << "\n"
191        << "layout(location=0) rayPayloadEXT int value;\n"
192        << "\n"
193        << "void main()\n"
194        << "{\n"
195        << "  for (int i=0; i<3; i++) {\n"
196        << "    modes.values[6*gl_LaunchIDEXT.x+2*i+1] = vec4(gl_HitTriangleVertexPositionsEXT[i], 0);\n"
197        << "  }\n"
198        << "}\n";
199 
200     // Should never miss to fill in with sentinel values to cause a failure
201     std::ostringstream miss;
202     miss << "#version 460 core\n"
203          << "#extension GL_EXT_ray_tracing : require\n"
204          << layoutDeclsStr << "\n"
205          << "layout(location=0) rayPayloadEXT int value;\n"
206          << "\n"
207          << "void main()\n"
208          << "{\n"
209          << "  for (int i=0; i<6; i++) {\n"
210          << "    modes.values[6*gl_LaunchIDEXT.x+i] = vec4(123.0f, 456.0f, 789.0f, 0.0f);\n"
211          << "  }\n"
212          << "}\n";
213 
214     programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
215     programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(miss.str())) << buildOptions;
216     programCollection.glslSources.add("ah") << glu::AnyHitSource(updateRayTracingGLSL(ah.str())) << buildOptions;
217     programCollection.glslSources.add("ch") << glu::ClosestHitSource(updateRayTracingGLSL(ch.str())) << buildOptions;
218 }
219 
createInstance(Context & context) const220 TestInstance *PositionFetchCase::createInstance(Context &context) const
221 {
222     return new PositionFetchInstance(context, m_params);
223 }
224 
PositionFetchInstance(Context & context,const TestParams & params)225 PositionFetchInstance::PositionFetchInstance(Context &context, const TestParams &params)
226     : TestInstance(context)
227     , m_params(params)
228 {
229 }
230 
iterate(void)231 tcu::TestStatus PositionFetchInstance::iterate(void)
232 {
233     const auto &vki    = m_context.getInstanceInterface();
234     const auto physDev = m_context.getPhysicalDevice();
235     const auto &vkd    = m_context.getDeviceInterface();
236     const auto device  = m_context.getDevice();
237     auto &alloc        = m_context.getDefaultAllocator();
238     const auto qIndex  = m_context.getUniversalQueueFamilyIndex();
239     const auto queue   = m_context.getUniversalQueue();
240     const auto stages  = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
241                         VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
242 
243     // Command pool and buffer.
244     const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
245     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
246     const auto cmdBuffer    = cmdBufferPtr.get();
247 
248     beginCommandBuffer(vkd, cmdBuffer);
249 
250     // If we add anything to the command buffer here that the AS builds depend on make sure
251     // to submit and wait when in CPU build mode
252 
253     // Build acceleration structures.
254     auto topLevelAS    = makeTopLevelAccelerationStructure();
255     auto bottomLevelAS = makeBottomLevelAccelerationStructure();
256 
257     const std::vector<tcu::Vec3> triangle = {
258         tcu::Vec3(0.0f, 0.0f, 0.0f),
259         tcu::Vec3(1.0f, 0.0f, 0.0f),
260         tcu::Vec3(0.0f, 1.0f, 0.0f),
261     };
262 
263     const VkTransformMatrixKHR notQuiteIdentityMatrix3x4 = {
264         {{0.98f, 0.0f, 0.0f, 0.0f}, {0.0f, 0.97f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.99f, 0.0f}}};
265 
266     de::SharedPtr<RaytracedGeometryBase> geometry =
267         makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, m_params.vertexFormat, VK_INDEX_TYPE_NONE_KHR);
268 
269     for (auto &v : triangle)
270     {
271         geometry->addVertex(v);
272     }
273 
274     bottomLevelAS->addGeometry(geometry);
275     bottomLevelAS->setBuildFlags(VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_DATA_ACCESS_KHR);
276     bottomLevelAS->setBuildType(m_params.buildType);
277     bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
278     de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr(bottomLevelAS.release());
279 
280     topLevelAS->setInstanceCount(1);
281     topLevelAS->setBuildType(m_params.buildType);
282     topLevelAS->addInstance(blasSharedPtr, (m_params.testFlagMask & TEST_FLAG_BIT_INSTANCE_TRANSFORM) ?
283                                                notQuiteIdentityMatrix3x4 :
284                                                identityMatrix3x4);
285     topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
286 
287     // One ray for this test
288     // XXX Should it be multiple triangles and one ray per triangle for more coverage?
289     // XXX If it's really one ray, the origin buffer is complete overkill
290     uint32_t numRays = 1; // XXX
291 
292     // SSBO buffer for origins.
293     const auto originsBufferSize = static_cast<VkDeviceSize>(sizeof(tcu::Vec4) * numRays);
294     const auto originsBufferInfo = makeBufferCreateInfo(originsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
295     BufferWithMemory originsBuffer(vkd, device, alloc, originsBufferInfo, MemoryRequirement::HostVisible);
296     auto &originsBufferAlloc = originsBuffer.getAllocation();
297     void *originsBufferData  = originsBufferAlloc.getHostPtr();
298 
299     std::vector<tcu::Vec4> origins;
300     std::vector<tcu::Vec3> expectedOutputPositions;
301     origins.reserve(numRays);
302     expectedOutputPositions.reserve(6 * numRays);
303 
304     // Fill in vector of expected outputs
305     for (uint32_t index = 0; index < numRays; index++)
306     {
307         for (uint32_t vert = 0; vert < 3; vert++)
308         {
309             tcu::Vec3 pos = triangle[vert];
310 
311             // One from CH, one from AH
312             expectedOutputPositions.push_back(pos);
313             expectedOutputPositions.push_back(pos);
314         }
315     }
316 
317     // XXX Arbitrary location and see above
318     for (uint32_t index = 0; index < numRays; index++)
319     {
320         origins.push_back(tcu::Vec4(0.25, 0.25, 1.0, 0.0));
321     }
322 
323     const auto originsBufferSizeSz = static_cast<size_t>(originsBufferSize);
324     deMemcpy(originsBufferData, origins.data(), originsBufferSizeSz);
325     flushAlloc(vkd, device, originsBufferAlloc);
326 
327     // Storage buffer for output modes
328     const auto outputPositionsBufferSize = static_cast<VkDeviceSize>(6 * 4 * sizeof(float) * numRays);
329     const auto outputPositionsBufferInfo =
330         makeBufferCreateInfo(outputPositionsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
331     BufferWithMemory outputPositionsBuffer(vkd, device, alloc, outputPositionsBufferInfo,
332                                            MemoryRequirement::HostVisible);
333     auto &outputPositionsBufferAlloc = outputPositionsBuffer.getAllocation();
334     void *outputPositionsBufferData  = outputPositionsBufferAlloc.getHostPtr();
335     deMemset(outputPositionsBufferData, 0xFF, static_cast<size_t>(outputPositionsBufferSize));
336     flushAlloc(vkd, device, outputPositionsBufferAlloc);
337 
338     // Descriptor set layout.
339     DescriptorSetLayoutBuilder dsLayoutBuilder;
340     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, stages);
341     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
342     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
343     const auto setLayout = dsLayoutBuilder.build(vkd, device);
344 
345     // Pipeline layout.
346     const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
347 
348     // Descriptor pool and set.
349     DescriptorPoolBuilder poolBuilder;
350     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
351     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
352     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
353     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
354     const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
355 
356     // Update descriptor set.
357     {
358         const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo = {
359             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
360             nullptr,
361             1u,
362             topLevelAS.get()->getPtr(),
363         };
364         const auto inStorageBufferInfo = makeDescriptorBufferInfo(originsBuffer.get(), 0ull, VK_WHOLE_SIZE);
365         const auto storageBufferInfo   = makeDescriptorBufferInfo(outputPositionsBuffer.get(), 0ull, VK_WHOLE_SIZE);
366 
367         DescriptorSetUpdateBuilder updateBuilder;
368         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u),
369                                   VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
370         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u),
371                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inStorageBufferInfo);
372         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u),
373                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferInfo);
374         updateBuilder.update(vkd, device);
375     }
376 
377     // Shader modules.
378     auto rgenModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0);
379     auto missModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0);
380     auto ahModule   = createShaderModule(vkd, device, m_context.getBinaryCollection().get("ah"), 0);
381     auto chModule   = createShaderModule(vkd, device, m_context.getBinaryCollection().get("ch"), 0);
382 
383     // Get some ray tracing properties.
384     uint32_t shaderGroupHandleSize    = 0u;
385     uint32_t shaderGroupBaseAlignment = 1u;
386     {
387         const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
388         shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
389         shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
390     }
391 
392     // Create raytracing pipeline and shader binding tables.
393     Move<VkPipeline> pipeline;
394     de::MovePtr<BufferWithMemory> raygenSBT;
395     de::MovePtr<BufferWithMemory> missSBT;
396     de::MovePtr<BufferWithMemory> hitSBT;
397     de::MovePtr<BufferWithMemory> callableSBT;
398 
399     auto raygenSBTRegion   = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
400     auto missSBTRegion     = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
401     auto hitSBTRegion      = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
402     auto callableSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
403 
404     {
405         const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
406         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, 0);
407         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, missModule, 1);
408         rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR, ahModule, 2);
409         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chModule, 2);
410 
411         pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
412 
413         raygenSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc,
414                                                                        shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
415         raygenSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0),
416                                                             shaderGroupHandleSize, shaderGroupHandleSize);
417 
418         missSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc,
419                                                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
420         missSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missSBT->get(), 0),
421                                                           shaderGroupHandleSize, shaderGroupHandleSize);
422 
423         hitSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize,
424                                                               shaderGroupBaseAlignment, 2, 1);
425         hitSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitSBT->get(), 0),
426                                                          shaderGroupHandleSize, shaderGroupHandleSize);
427     }
428 
429     // Trace rays.
430     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
431     vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
432                               &descriptorSet.get(), 0u, nullptr);
433     vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &missSBTRegion, &hitSBTRegion, &callableSBTRegion, numRays, 1u,
434                         1u);
435 
436     // Barrier for the output buffer.
437     const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
438     vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
439                            &bufferBarrier, 0u, nullptr, 0u, nullptr);
440 
441     endCommandBuffer(vkd, cmdBuffer);
442     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
443 
444     // Verify results.
445     std::vector<tcu::Vec4> outputData(expectedOutputPositions.size());
446     const auto outputPositionsBufferSizeSz = static_cast<size_t>(outputPositionsBufferSize);
447 
448     invalidateAlloc(vkd, device, outputPositionsBufferAlloc);
449     DE_ASSERT(de::dataSize(outputData) == outputPositionsBufferSizeSz);
450     deMemcpy(outputData.data(), outputPositionsBufferData, outputPositionsBufferSizeSz);
451 
452     for (size_t i = 0; i < outputData.size(); ++i)
453     {
454         /*const */ auto &outVal = outputData[i]; // Should be const but .xyz() isn't
455         tcu::Vec3 outVec3       = outVal.xyz();
456         const auto &expectedVal = expectedOutputPositions[i];
457         const auto &diff        = expectedOutputPositions[i] - outVec3;
458         float len               = dot(diff, diff);
459 
460         // XXX Find a better epsilon
461         if (!(len < 1e-5))
462         {
463             std::ostringstream msg;
464             msg << "Unexpected value found for element " << i << ": expected " << expectedVal << " and found " << outVal
465                 << ";";
466             TCU_FAIL(msg.str());
467         }
468 #if 0
469         else
470         {
471             std::ostringstream msg;
472             msg << "Expected value found for element " << i << ": expected " << expectedVal << " and found " << outVal << ";\n";
473             std::cout << msg.str();
474         }
475 #endif
476     }
477 
478     return tcu::TestStatus::pass("Pass");
479 }
480 
481 } // namespace
482 
createPositionFetchTests(tcu::TestContext & testCtx)483 tcu::TestCaseGroup *createPositionFetchTests(tcu::TestContext &testCtx)
484 {
485     // Test ray pipeline shaders using position fetch
486     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "position_fetch"));
487 
488     struct
489     {
490         vk::VkAccelerationStructureBuildTypeKHR buildType;
491         const char *name;
492     } buildTypes[] = {
493         {VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR, "cpu_built"},
494         {VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR, "gpu_built"},
495     };
496 
497     const VkFormat vertexFormats[] = {
498         // Mandatory formats.
499         VK_FORMAT_R32G32_SFLOAT,
500         VK_FORMAT_R32G32B32_SFLOAT,
501         VK_FORMAT_R16G16_SFLOAT,
502         VK_FORMAT_R16G16B16A16_SFLOAT,
503         VK_FORMAT_R16G16_SNORM,
504         VK_FORMAT_R16G16B16A16_SNORM,
505 
506         // Additional formats.
507         VK_FORMAT_R8G8_SNORM,
508         VK_FORMAT_R8G8B8_SNORM,
509         VK_FORMAT_R8G8B8A8_SNORM,
510         VK_FORMAT_R16G16B16_SNORM,
511         VK_FORMAT_R16G16B16_SFLOAT,
512         VK_FORMAT_R32G32B32A32_SFLOAT,
513         VK_FORMAT_R64G64_SFLOAT,
514         VK_FORMAT_R64G64B64_SFLOAT,
515         VK_FORMAT_R64G64B64A64_SFLOAT,
516     };
517 
518     for (size_t buildTypeNdx = 0; buildTypeNdx < DE_LENGTH_OF_ARRAY(buildTypes); ++buildTypeNdx)
519     {
520         de::MovePtr<tcu::TestCaseGroup> buildGroup(
521             new tcu::TestCaseGroup(group->getTestContext(), buildTypes[buildTypeNdx].name));
522 
523         for (size_t vertexFormatNdx = 0; vertexFormatNdx < DE_LENGTH_OF_ARRAY(vertexFormats); ++vertexFormatNdx)
524         {
525             const auto format     = vertexFormats[vertexFormatNdx];
526             const auto formatName = getFormatSimpleName(format);
527 
528             de::MovePtr<tcu::TestCaseGroup> vertexFormatGroup(
529                 new tcu::TestCaseGroup(group->getTestContext(), formatName.c_str()));
530 
531             for (uint32_t testFlagMask = 0; testFlagMask < TEST_FLAG_BIT_LAST; testFlagMask++)
532             {
533                 std::string maskName = "";
534 
535                 for (uint32_t bit = 0; bit < testFlagBitNames.size(); bit++)
536                 {
537                     if (testFlagMask & (1 << bit))
538                     {
539                         if (maskName != "")
540                             maskName += "_";
541                         maskName += testFlagBitNames[bit];
542                     }
543                 }
544                 if (maskName == "")
545                     maskName = "NoFlags";
546 
547                 de::MovePtr<tcu::TestCaseGroup> testFlagGroup(
548                     new tcu::TestCaseGroup(group->getTestContext(), maskName.c_str()));
549 
550                 TestParams testParams{
551                     buildTypes[buildTypeNdx].buildType,
552                     format,
553                     testFlagMask,
554                 };
555 
556                 vertexFormatGroup->addChild(new PositionFetchCase(testCtx, maskName, testParams));
557             }
558             buildGroup->addChild(vertexFormatGroup.release());
559         }
560         group->addChild(buildGroup.release());
561     }
562 
563     return group.release();
564 }
565 
566 } // namespace RayTracing
567 } // namespace vkt
568