1 /*-------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  * Copyright (c) 2021 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Ray Tracing Barycentric Coordinates Tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktRayTracingBarycentricCoordinatesTests.hpp"
26 #include "vktTestCase.hpp"
27 
28 #include "vkRayTracingUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkBufferWithMemory.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35 
36 #include "deUniquePtr.hpp"
37 #include "deRandom.hpp"
38 
39 #include <sstream>
40 #include <vector>
41 
42 namespace vkt
43 {
44 namespace RayTracing
45 {
46 
47 namespace
48 {
49 
50 using namespace vk;
51 
// Hit-stage configurations exercised by this test group.
enum class TestCaseRT
{
    CLOSEST_HIT,                   // The coordinate-writing shader is bound as a closest hit shader.
    ANY_HIT,                       // The same coordinate-writing shader is bound as an any hit shader.
    CLOSEST_AND_ANY_HIT_TERMINATE, // Closest hit writes coordinates; an extra any hit shader may call terminateRayEXT.
};

// Parameters describing a single test case.
struct TestParams
{
    TestCaseRT testCase; // Selected hit-stage configuration.
    uint32_t seed;       // Seed used to generate the random ray targets.
};
64 
getUsedStages(const TestParams & params)65 VkShaderStageFlags getUsedStages(const TestParams &params)
66 {
67     VkShaderStageFlags stageFlags{VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR};
68     if (params.testCase == TestCaseRT::CLOSEST_HIT)
69         stageFlags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
70     else if (params.testCase == TestCaseRT::ANY_HIT)
71         stageFlags |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
72     else if (params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
73         stageFlags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
74 
75     return stageFlags;
76 }
77 
// Scene geometry: one triangle on the plane z = kZCoord whose x/y extents are +-kXYCoordAbs.
constexpr float kZCoord{5.0f};
constexpr float kXYCoordAbs{1.0f};

// Tolerance applied to the reported barycentric coordinates. The ray interval is narrowed around
// t = 1.0 by the same amount, so the hit distance must be equally precise.
constexpr float kThreshold{0.001f};
constexpr float kTMin{1.0f - kThreshold};
constexpr float kTMax{1.0f + kThreshold};

// Number of rays launched; also the element count of the direction and output arrays.
constexpr uint32_t kNumRays{20u};
85 
86 class BarycentricCoordinatesCase : public TestCase
87 {
88 public:
89     BarycentricCoordinatesCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params);
~BarycentricCoordinatesCase(void)90     virtual ~BarycentricCoordinatesCase(void)
91     {
92     }
93 
94     virtual void checkSupport(Context &context) const;
95     virtual void initPrograms(vk::SourceCollections &programCollection) const;
96     virtual TestInstance *createInstance(Context &context) const;
97 
98 protected:
99     TestParams m_params;
100 };
101 
102 class BarycentricCoordinatesInstance : public TestInstance
103 {
104 public:
105     BarycentricCoordinatesInstance(Context &context, const TestParams &params);
~BarycentricCoordinatesInstance(void)106     virtual ~BarycentricCoordinatesInstance(void)
107     {
108     }
109 
110     virtual tcu::TestStatus iterate(void);
111 
112 protected:
113     TestParams m_params;
114 };
115 
BarycentricCoordinatesCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)116 BarycentricCoordinatesCase::BarycentricCoordinatesCase(tcu::TestContext &testCtx, const std::string &name,
117                                                        const TestParams &params)
118     : TestCase(testCtx, name)
119     , m_params(params)
120 {
121 }
122 
checkSupport(Context & context) const123 void BarycentricCoordinatesCase::checkSupport(Context &context) const
124 {
125     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
126     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
127 }
128 
initPrograms(vk::SourceCollections & programCollection) const129 void BarycentricCoordinatesCase::initPrograms(vk::SourceCollections &programCollection) const
130 {
131     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
132 
133     std::ostringstream layoutDecls;
134     layoutDecls << "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
135                 << "layout(set=0, binding=1) uniform RayDirections {\n"
136                 << "  vec4 values[" << kNumRays << "];\n"
137                 << "} directions;\n"
138                 << "layout(set=0, binding=2, std430) buffer OutputBarycentrics {\n"
139                 << "  vec4 values[" << kNumRays << "];\n"
140                 << "} coordinates;\n";
141     const auto layoutDeclsStr = layoutDecls.str();
142 
143     std::ostringstream rgen;
144     rgen << "#version 460 core\n"
145          << "#extension GL_EXT_ray_tracing : require\n"
146          << "\n"
147          << "layout(location=0) rayPayloadEXT vec3 hitValue;\n"
148          << "\n"
149          << layoutDeclsStr << "\n"
150          << "void main()\n"
151          << "{\n"
152          << "  const uint  cullMask  = 0xFF;\n"
153          << "  const vec3  origin    = vec3(0.0, 0.0, 0.0);\n"
154          << "  const vec3  direction = directions.values[gl_LaunchIDEXT.x].xyz;\n"
155          << "  const float tMin      = " << kTMin << ";\n"
156          << "  const float tMax      = " << kTMax << ";\n"
157          << "  traceRayEXT(topLevelAS, gl_RayFlagsNoneEXT, cullMask, 0, 0, 0, origin, tMin, direction, tMax, 0);\n"
158          << "}\n";
159 
160     std::ostringstream chit;
161     chit << "#version 460 core\n"
162          << "#extension GL_EXT_ray_tracing : require\n"
163          << "\n"
164          << "hitAttributeEXT vec2 baryCoord;\n"
165          << "\n"
166          << layoutDeclsStr << "\n"
167          << "void main()\n"
168          << "{\n"
169          << "  coordinates.values[gl_LaunchIDEXT.x].xy = baryCoord;\n"
170          << "}\n";
171 
172     std::ostringstream ahitTerminate;
173     ahitTerminate << "#version 460 core\n"
174                   << "#extension GL_EXT_ray_tracing : require\n"
175                   << "\n"
176                   << "hitAttributeEXT vec2 baryCoord;\n"
177                   << "\n"
178                   << layoutDeclsStr << "\n"
179                   << "void main()\n"
180                   << "{\n"
181                   << "  coordinates.values[gl_LaunchIDEXT.x].z = 0.999;\n"
182                   << "  if(baryCoord.x < 0.7){\n"
183                   << "    terminateRayEXT;\n"
184                   << "    coordinates.values[gl_LaunchIDEXT.x].z = 0.5;\n"
185                   << "  }\n"
186                   << "}\n";
187 
188     std::ostringstream miss;
189     miss << "#version 460 core\n"
190          << "#extension GL_EXT_ray_tracing : require\n"
191          << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
192          << layoutDeclsStr << "\n"
193          << "void main()\n"
194          << "{\n"
195          << "  coordinates.values[gl_LaunchIDEXT.x] = vec4(-1.0, -1.0, -1.0, -1.0);\n"
196          << "}\n";
197 
198     programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
199     programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(miss.str())) << buildOptions;
200 
201     if (m_params.testCase == TestCaseRT::CLOSEST_HIT)
202         programCollection.glslSources.add("chit")
203             << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
204     else if (m_params.testCase == TestCaseRT::ANY_HIT)
205         programCollection.glslSources.add("chit")
206             << glu::AnyHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
207     else if (m_params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
208     {
209         programCollection.glslSources.add("chit")
210             << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
211         programCollection.glslSources.add("ahitTerminate")
212             << glu::AnyHitSource(updateRayTracingGLSL(ahitTerminate.str())) << buildOptions;
213     }
214     else
215         DE_ASSERT(false);
216 }
217 
createInstance(Context & context) const218 TestInstance *BarycentricCoordinatesCase::createInstance(Context &context) const
219 {
220     return new BarycentricCoordinatesInstance(context, m_params);
221 }
222 
BarycentricCoordinatesInstance(Context & context,const TestParams & params)223 BarycentricCoordinatesInstance::BarycentricCoordinatesInstance(Context &context, const TestParams &params)
224     : TestInstance(context)
225     , m_params(params)
226 {
227 }
228 
229 // Calculates coordinates in a triangle given barycentric coordinates b and c.
calcCoordinates(const std::vector<tcu::Vec3> & triangle,float b,float c)230 tcu::Vec3 calcCoordinates(const std::vector<tcu::Vec3> &triangle, float b, float c)
231 {
232     DE_ASSERT(triangle.size() == 3u);
233     DE_ASSERT(b > 0.0f);
234     DE_ASSERT(c > 0.0f);
235     DE_ASSERT(b + c < 1.0f);
236 
237     const float a = 1.0f - b - c;
238     DE_ASSERT(a > 0.0f);
239 
240     return triangle[0] * a + triangle[1] * b + triangle[2] * c;
241 }
242 
243 // Return a, b, c with a close to 1.0f and (b, c) close to 0.0f.
getBarycentricVertex(void)244 tcu::Vec3 getBarycentricVertex(void)
245 {
246     const float a   = 0.999f;
247     const float aux = 1.0f - a;
248     const float b   = aux / 2.0f;
249     const float c   = b;
250 
251     return tcu::Vec3(a, b, c);
252 }
253 
extendToV4(const tcu::Vec3 & vec3)254 tcu::Vec4 extendToV4(const tcu::Vec3 &vec3)
255 {
256     return tcu::Vec4(vec3.x(), vec3.y(), vec3.z(), 0.0f);
257 }
258 
iterate(void)259 tcu::TestStatus BarycentricCoordinatesInstance::iterate(void)
260 {
261     const auto &vki    = m_context.getInstanceInterface();
262     const auto physDev = m_context.getPhysicalDevice();
263     const auto &vkd    = m_context.getDeviceInterface();
264     const auto device  = m_context.getDevice();
265     auto &alloc        = m_context.getDefaultAllocator();
266     const auto qIndex  = m_context.getUniversalQueueFamilyIndex();
267     const auto queue   = m_context.getUniversalQueue();
268     const auto stages  = getUsedStages(m_params);
269 
270     // Command pool and buffer.
271     const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
272     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
273     const auto cmdBuffer    = cmdBufferPtr.get();
274 
275     beginCommandBuffer(vkd, cmdBuffer);
276 
277     // Build acceleration structures.
278     auto topLevelAS    = makeTopLevelAccelerationStructure();
279     auto bottomLevelAS = makeBottomLevelAccelerationStructure();
280 
281     const std::vector<tcu::Vec3> triangle = {
282         tcu::Vec3(0.0f, -kXYCoordAbs, kZCoord),
283         tcu::Vec3(-kXYCoordAbs, kXYCoordAbs, kZCoord),
284         tcu::Vec3(kXYCoordAbs, kXYCoordAbs, kZCoord),
285     };
286 
287     bottomLevelAS->addGeometry(triangle, true /*is triangles*/, VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
288     bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
289     de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr(bottomLevelAS.release());
290 
291     topLevelAS->setInstanceCount(1);
292     topLevelAS->addInstance(blasSharedPtr, identityMatrix3x4, 0, 0xFFu, 0u,
293                             VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR);
294     topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
295 
296     // Uniform buffer for directions.
297     const auto directionsBufferSize = static_cast<VkDeviceSize>(sizeof(tcu::Vec4) * kNumRays);
298     const auto directionsBufferInfo = makeBufferCreateInfo(directionsBufferSize, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
299     BufferWithMemory directionsBuffer(vkd, device, alloc, directionsBufferInfo, MemoryRequirement::HostVisible);
300     auto &directionsBufferAlloc = directionsBuffer.getAllocation();
301     void *directionsBufferData  = directionsBufferAlloc.getHostPtr();
302 
303     // Generate rays towards the 3 triangle coordinates (avoiding exact vertices) and additional coordinates.
304     std::vector<tcu::Vec4> directions;
305     std::vector<tcu::Vec4> expectedOutputCoordinates;
306     directions.reserve(kNumRays);
307     expectedOutputCoordinates.reserve(kNumRays);
308 
309     const auto barycentricABC = getBarycentricVertex();
310 
311     directions.push_back(extendToV4(calcCoordinates(triangle, barycentricABC.x(), barycentricABC.y())));
312     directions.push_back(extendToV4(calcCoordinates(triangle, barycentricABC.y(), barycentricABC.x())));
313     directions.push_back(extendToV4(calcCoordinates(triangle, barycentricABC.y(), barycentricABC.z())));
314 
315     float expectedZ = 0.0f;
316     // Set expectedZ to the same value as the AnyHit shader sets
317     if (m_params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
318         expectedZ = 0.999f;
319 
320     expectedOutputCoordinates.push_back(tcu::Vec4(barycentricABC.x(), barycentricABC.y(), expectedZ, 0.0f));
321     expectedOutputCoordinates.push_back(tcu::Vec4(barycentricABC.y(), barycentricABC.x(), expectedZ, 0.0f));
322     expectedOutputCoordinates.push_back(tcu::Vec4(barycentricABC.y(), barycentricABC.z(), expectedZ, 0.0f));
323 
324     de::Random rnd(m_params.seed);
325     while (directions.size() < kNumRays)
326     {
327         // Avoid 0.0 when choosing b and c.
328         float b;
329         while ((b = rnd.getFloat()) == 0.0f)
330             ;
331         float c;
332         while ((c = rnd.getFloat(0.0f, 1.0f - b)) == 0.0f)
333             ;
334         directions.push_back(extendToV4(calcCoordinates(triangle, b, c)));
335         if (m_params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
336             expectedOutputCoordinates.push_back(tcu::Vec4(b, c, expectedZ, 0.0f));
337         else
338             expectedOutputCoordinates.push_back(tcu::Vec4(b, c, 0.0f, 0.0f));
339     }
340 
341     deMemcpy(directionsBufferData, directions.data(), directionsBufferSize);
342     flushAlloc(vkd, device, directionsBufferAlloc);
343 
344     // Storage buffer for output barycentric coordinates.
345     const auto barycoordsBufferSize = static_cast<VkDeviceSize>(sizeof(tcu::Vec4) * kNumRays);
346     const auto barycoordsBufferInfo = makeBufferCreateInfo(barycoordsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
347     BufferWithMemory barycoordsBuffer(vkd, device, alloc, barycoordsBufferInfo, MemoryRequirement::HostVisible);
348     auto &barycoordsBufferAlloc = barycoordsBuffer.getAllocation();
349     void *barycoordsBufferData  = barycoordsBufferAlloc.getHostPtr();
350     deMemset(barycoordsBufferData, 0, static_cast<size_t>(barycoordsBufferSize));
351     flushAlloc(vkd, device, barycoordsBufferAlloc);
352 
353     // Descriptor set layout.
354     DescriptorSetLayoutBuilder dsLayoutBuilder;
355     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, stages);
356     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, stages);
357     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
358     const auto setLayout = dsLayoutBuilder.build(vkd, device);
359 
360     // Pipeline layout.
361     const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
362 
363     // Descriptor pool and set.
364     DescriptorPoolBuilder poolBuilder;
365     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
366     poolBuilder.addType(VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
367     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
368     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
369     const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
370 
371     // Update descriptor set.
372     {
373         const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo = {
374             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
375             nullptr,
376             1u,
377             topLevelAS.get()->getPtr(),
378         };
379         const auto uniformBufferInfo = makeDescriptorBufferInfo(directionsBuffer.get(), 0ull, VK_WHOLE_SIZE);
380         const auto storageBufferInfo = makeDescriptorBufferInfo(barycoordsBuffer.get(), 0ull, VK_WHOLE_SIZE);
381 
382         DescriptorSetUpdateBuilder updateBuilder;
383         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u),
384                                   VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
385         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u),
386                                   VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, &uniformBufferInfo);
387         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u),
388                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferInfo);
389         updateBuilder.update(vkd, device);
390     }
391 
392     // Shader modules.
393     auto rgenModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0);
394     auto missModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0);
395     auto chitModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0);
396     Move<VkShaderModule> ahitTerminateModule;
397     if (m_params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
398         ahitTerminateModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("ahitTerminate"), 0);
399 
400     // Get some ray tracing properties.
401     uint32_t shaderGroupHandleSize    = 0u;
402     uint32_t shaderGroupBaseAlignment = 1u;
403     {
404         const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
405         shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
406         shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
407     }
408 
409     // Create raytracing pipeline and shader binding tables.
410     Move<VkPipeline> pipeline;
411     de::MovePtr<BufferWithMemory> raygenSBT;
412     de::MovePtr<BufferWithMemory> missSBT;
413     de::MovePtr<BufferWithMemory> hitSBT;
414     de::MovePtr<BufferWithMemory> callableSBT;
415 
416     auto raygenSBTRegion   = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
417     auto missSBTRegion     = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
418     auto hitSBTRegion      = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
419     auto callableSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
420 
421     {
422         const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
423         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, 0);
424         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, missModule, 1);
425         if (m_params.testCase == TestCaseRT::CLOSEST_HIT)
426             rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chitModule, 2);
427         if (m_params.testCase == TestCaseRT::ANY_HIT)
428             rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR, chitModule, 2);
429         else if (m_params.testCase == TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE)
430         {
431             rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chitModule, 2);
432             rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR, ahitTerminateModule, 2);
433         }
434 
435         pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
436 
437         raygenSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc,
438                                                                        shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
439         raygenSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0),
440                                                             shaderGroupHandleSize, shaderGroupHandleSize);
441 
442         missSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc,
443                                                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
444         missSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missSBT->get(), 0),
445                                                           shaderGroupHandleSize, shaderGroupHandleSize);
446 
447         hitSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize,
448                                                               shaderGroupBaseAlignment, 2, 1);
449         hitSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitSBT->get(), 0),
450                                                          shaderGroupHandleSize, shaderGroupHandleSize);
451     }
452 
453     // Trace rays.
454     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
455     vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
456                               &descriptorSet.get(), 0u, nullptr);
457     vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &missSBTRegion, &hitSBTRegion, &callableSBTRegion, kNumRays, 1u,
458                         1u);
459 
460     // Barrier for the output buffer.
461     const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
462     vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
463                            &bufferBarrier, 0u, nullptr, 0u, nullptr);
464 
465     endCommandBuffer(vkd, cmdBuffer);
466     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
467 
468     // Verify results.
469     std::vector<tcu::Vec4> outputData(expectedOutputCoordinates.size());
470     const auto barycoordsBufferSizeSz = static_cast<size_t>(barycoordsBufferSize);
471 
472     invalidateAlloc(vkd, device, barycoordsBufferAlloc);
473     DE_ASSERT(de::dataSize(outputData) == barycoordsBufferSizeSz);
474     deMemcpy(outputData.data(), barycoordsBufferData, barycoordsBufferSizeSz);
475 
476     for (size_t i = 0; i < outputData.size(); ++i)
477     {
478         const auto &outVal      = outputData[i];
479         const auto &expectedVal = expectedOutputCoordinates[i];
480 
481         if (outVal.z() != expectedVal.z() || outVal.w() != 0.0f || de::abs(outVal.x() - expectedVal.x()) > kThreshold ||
482             de::abs(outVal.y() - expectedVal.y()) > kThreshold)
483         {
484             std::ostringstream msg;
485             msg << "Unexpected value found for ray " << i << ": expected " << expectedVal << " and found " << outVal
486                 << ";";
487             TCU_FAIL(msg.str());
488         }
489     }
490     return tcu::TestStatus::pass("Pass");
491 }
492 
493 } // namespace
494 
createBarycentricCoordinatesTests(tcu::TestContext & testCtx)495 tcu::TestCaseGroup *createBarycentricCoordinatesTests(tcu::TestContext &testCtx)
496 {
497     using GroupPtr = de::MovePtr<tcu::TestCaseGroup>;
498 
499     // Test barycentric coordinates reported in hit attributes
500     GroupPtr mainGroup(new tcu::TestCaseGroup(testCtx, "barycentric_coordinates"));
501 
502     uint32_t seed = 1614343620u;
503     mainGroup->addChild(new BarycentricCoordinatesCase(testCtx, "chit", TestParams{TestCaseRT::CLOSEST_HIT, seed++}));
504     mainGroup->addChild(new BarycentricCoordinatesCase(testCtx, "ahit", TestParams{TestCaseRT::ANY_HIT, seed++}));
505     mainGroup->addChild(new BarycentricCoordinatesCase(testCtx, "ahitTerminate",
506                                                        TestParams{TestCaseRT::CLOSEST_AND_ANY_HIT_TERMINATE, seed++}));
507 
508     return mainGroup.release();
509 }
510 
511 } // namespace RayTracing
512 } // namespace vkt
513