1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2021 The Khronos Group Inc.
6 * Copyright (c) 2021 Valve Corporation.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
 * \brief Tests using non-uniform arguments with traceRayEXT().
23 *//*--------------------------------------------------------------------*/
24
#include "vktRayTracingNonUniformArgsTests.hpp"
#include "vktTestCase.hpp"

#include "vkRayTracingUtil.hpp"
#include "vkObjUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkBarrierUtil.hpp"

#include "tcuTestLog.hpp"

#include <vector>
#include <iostream>
#include <limits>
#include <sstream>
39
40 namespace vkt
41 {
42 namespace RayTracing
43 {
44 namespace
45 {
46
47 using namespace vk;
48
// Selects which traceRayEXT() argument receives a "bad" value so that the ray ends up in a miss shader.
enum class MissCause
{
    NONE = 0,    // Every argument is good: the ray is expected to hit.
    FLAGS,       // rayFlags contains SkipTrianglesKHR.
    CULL_MASK,   // cullMask does not overlap the instance mask.
    ORIGIN,      // Origin placed far above the triangles.
    TMIN,        // Tmin lies past the triangle.
    DIRECTION,   // Direction points away from the triangles.
    TMAX,        // Tmax lies before the triangle.
    CAUSE_COUNT, // Number of entries; not an actual cause.
};
61
62 struct NonUniformParams
63 {
64 bool miss;
65
66 struct
67 {
68 uint32_t rayTypeCount;
69 uint32_t rayType;
70 } hitParams;
71
72 struct
73 {
74 MissCause missCause;
75 uint32_t missIndex;
76 } missParams;
77 };
78
79 class NonUniformArgsCase : public TestCase
80 {
81 public:
82 NonUniformArgsCase(tcu::TestContext &testCtx, const std::string &name, const NonUniformParams ¶ms);
~NonUniformArgsCase(void)83 virtual ~NonUniformArgsCase(void)
84 {
85 }
86
87 virtual void checkSupport(Context &context) const;
88 virtual void initPrograms(vk::SourceCollections &programCollection) const;
89 virtual TestInstance *createInstance(Context &context) const;
90
91 protected:
92 NonUniformParams m_params;
93 };
94
95 class NonUniformArgsInstance : public TestInstance
96 {
97 public:
98 NonUniformArgsInstance(Context &context, const NonUniformParams ¶ms);
~NonUniformArgsInstance(void)99 virtual ~NonUniformArgsInstance(void)
100 {
101 }
102
103 virtual tcu::TestStatus iterate(void);
104
105 protected:
106 NonUniformParams m_params;
107 };
108
NonUniformArgsCase(tcu::TestContext & testCtx,const std::string & name,const NonUniformParams & params)109 NonUniformArgsCase::NonUniformArgsCase(tcu::TestContext &testCtx, const std::string &name,
110 const NonUniformParams ¶ms)
111 : TestCase(testCtx, name)
112 , m_params(params)
113 {
114 }
115
checkSupport(Context & context) const116 void NonUniformArgsCase::checkSupport(Context &context) const
117 {
118 context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
119 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
120 }
121
122 struct ArgsBufferData
123 {
124 tcu::Vec4 origin;
125 tcu::Vec4 direction;
126 float Tmin;
127 float Tmax;
128 uint32_t rayFlags;
129 uint32_t cullMask;
130 uint32_t sbtRecordOffset;
131 uint32_t sbtRecordStride;
132 uint32_t missIndex;
133 };
134
initPrograms(vk::SourceCollections & programCollection) const135 void NonUniformArgsCase::initPrograms(vk::SourceCollections &programCollection) const
136 {
137 const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
138
139 std::ostringstream descriptors;
140 descriptors << "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
141 << "layout(set=0, binding=1, std430) buffer ArgumentsBlock {\n" // Must match ArgsBufferData.
142 << " vec4 origin;\n"
143 << " vec4 direction;\n"
144 << " float Tmin;\n"
145 << " float Tmax;\n"
146 << " uint rayFlags;\n"
147 << " uint cullMask;\n"
148 << " uint sbtRecordOffset;\n"
149 << " uint sbtRecordStride;\n"
150 << " uint missIndex;\n"
151 << "} args;\n"
152 << "layout(set=0, binding=2, std430) buffer ResultBlock {\n"
153 << " uint shaderId;\n"
154 << "} result;\n";
155 const auto descriptorsStr = descriptors.str();
156
157 std::ostringstream rgen;
158 rgen << "#version 460 core\n"
159 << "#extension GL_EXT_ray_tracing : require\n"
160 << "\n"
161 << descriptorsStr << "layout(location=0) rayPayloadEXT vec4 unused;\n"
162 << "\n"
163 << "void main()\n"
164 << "{\n"
165 << " traceRayEXT(topLevelAS,\n"
166 << " args.rayFlags,\n"
167 << " args.cullMask,\n"
168 << " args.sbtRecordOffset,\n"
169 << " args.sbtRecordStride,\n"
170 << " args.missIndex,\n"
171 << " args.origin.xyz,\n"
172 << " args.Tmin,\n"
173 << " args.direction.xyz,\n"
174 << " args.Tmax,\n"
175 << " 0);\n"
176 << "}\n";
177
178 std::ostringstream chit;
179 chit << "#version 460 core\n"
180 << "#extension GL_EXT_ray_tracing : require\n"
181 << "\n"
182 << descriptorsStr << "layout(constant_id=0) const uint chitShaderId = 0;\n"
183 << "layout(location=0) rayPayloadInEXT vec4 unused;\n"
184 << "\n"
185 << "void main()\n"
186 << "{\n"
187 << " result.shaderId = chitShaderId;\n"
188 << "}\n";
189
190 std::ostringstream miss;
191 miss << "#version 460 core\n"
192 << "#extension GL_EXT_ray_tracing : require\n"
193 << "\n"
194 << descriptorsStr << "layout(constant_id=0) const uint missShaderId = 0;\n"
195 << "layout(location=0) rayPayloadInEXT vec4 unused;\n"
196 << "\n"
197 << "void main()\n"
198 << "{\n"
199 << " result.shaderId = missShaderId;\n"
200 << "}\n";
201
202 programCollection.glslSources.add("rgen") << glu::RaygenSource(rgen.str()) << buildOptions;
203 programCollection.glslSources.add("chit") << glu::ClosestHitSource(chit.str()) << buildOptions;
204 programCollection.glslSources.add("miss") << glu::MissSource(miss.str()) << buildOptions;
205 }
206
createInstance(Context & context) const207 TestInstance *NonUniformArgsCase::createInstance(Context &context) const
208 {
209 return new NonUniformArgsInstance(context, m_params);
210 }
211
NonUniformArgsInstance(Context & context,const NonUniformParams & params)212 NonUniformArgsInstance::NonUniformArgsInstance(Context &context, const NonUniformParams ¶ms)
213 : TestInstance(context)
214 , m_params(params)
215 {
216 }
217
joinMostLeast(uint32_t most,uint32_t least)218 uint32_t joinMostLeast(uint32_t most, uint32_t least)
219 {
220 constexpr auto kMaxUint16 = static_cast<uint32_t>(std::numeric_limits<uint16_t>::max());
221 DE_UNREF(kMaxUint16); // For release builds.
222 DE_ASSERT(most <= kMaxUint16 && least <= kMaxUint16);
223 return ((most << 16) | least);
224 }
225
makeMissId(uint32_t missIndex)226 uint32_t makeMissId(uint32_t missIndex)
227 {
228 // 1 on the highest 16 bits for miss shaders.
229 return joinMostLeast(1u, missIndex);
230 }
231
makeChitId(uint32_t chitIndex)232 uint32_t makeChitId(uint32_t chitIndex)
233 {
234 // 2 on the highest 16 bits for closest hit shaders.
235 return joinMostLeast(2u, chitIndex);
236 }
237
// Records, submits and verifies a single-ray trace.
//
// Steps:
// - Build a BLAS with two triangles (the offscreen one at z=-5 first, the onscreen one at z=5 second) and a TLAS
//   with a single instance.
// - Create one miss group per index in [0, missIndex] and one closest hit group per (geometry, ray type) pair.
//   Each group is specialized with a unique id (see makeMissId/makeChitId).
// - Fill the argument buffer with good or bad traceRayEXT() arguments according to m_params.
// - Trace a 1x1x1 grid and compare the id written by the invoked shader with the expected one.
// Returns pass when the ids match and fail (with both values in the message) otherwise.
tcu::TestStatus NonUniformArgsInstance::iterate(void)
{
    const auto &vki = m_context.getInstanceInterface();
    const auto physDev = m_context.getPhysicalDevice();
    const auto &vkd = m_context.getDeviceInterface();
    const auto device = m_context.getDevice();
    auto &alloc = m_context.getDefaultAllocator();
    const auto qIndex = m_context.getUniversalQueueFamilyIndex();
    const auto queue = m_context.getUniversalQueue();
    const auto stages =
        (VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR);

    // Geometry data constants.
    const std::vector<tcu::Vec3> kOffscreenTriangle = {
        // Triangle around (x=0, y=2) z=-5
        tcu::Vec3(0.0f, 2.5f, -5.0f),
        tcu::Vec3(-0.5f, 1.5f, -5.0f),
        tcu::Vec3(0.5f, 1.5f, -5.0f),
    };
    const std::vector<tcu::Vec3> kOnscreenTriangle = {
        // Triangle around (x=0, y=2) z=5
        tcu::Vec3(0.0f, 2.5f, 5.0f),
        tcu::Vec3(-0.5f, 1.5f, 5.0f),
        tcu::Vec3(0.5f, 1.5f, 5.0f),
    };
    // Good values make the ray from the origin towards +z hit the onscreen triangle at t=5. Each bad value
    // changes exactly one argument so that the ray misses.
    const tcu::Vec4 kGoodOrigin(0.0f, 2.0f, 0.0f, 0.0f);    // Around (x=0, y=2) z=0.
    const tcu::Vec4 kBadOrigin(0.0f, 8.0f, 0.0f, 0.0f);     // Too high, around (x=0, y=8) depth 0.
    const tcu::Vec4 kGoodDirection(0.0f, 0.0f, 1.0f, 0.0f); // Towards +z.
    const tcu::Vec4 kBadDirection(1.0f, 0.0f, 0.0f, 0.0f);  // Towards +x.
    const float kGoodTmin        = 4.0f;                    // Good to travel from z=0 to z=5.
    const float kGoodTmax        = 6.0f;                    // Ditto.
    const float kBadTmin         = 5.5f;                    // Tmin after triangle.
    const float kBadTmax         = 4.5f;                    // Tmax before triangle.
    const uint32_t kGoodFlags    = 0u;                      // MaskNone
    const uint32_t kBadFlags     = 256u;                    // SkipTrianglesKHR
    const uint32_t kGoodCullMask = 0x0Fu;                   // Matches instance.
    const uint32_t kBadCullMask  = 0xF0u;                   // Does not match instance.

    // Command pool and buffer.
    const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
    const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
    const auto cmdBuffer    = cmdBufferPtr.get();

    // The acceleration structure builds and the trace are recorded into the same command buffer.
    beginCommandBuffer(vkd, cmdBuffer);

    // Build acceleration structures.
    auto topLevelAS    = makeTopLevelAccelerationStructure();
    auto bottomLevelAS = makeBottomLevelAccelerationStructure();

    // Putting the offscreen triangle first makes sure hits have a geometryIndex=1, meaning sbtRecordStride matters.
    std::vector<const std::vector<tcu::Vec3> *> geometries;
    geometries.push_back(&kOffscreenTriangle);
    geometries.push_back(&kOnscreenTriangle);

    for (const auto &geometryPtr : geometries)
        bottomLevelAS->addGeometry(*geometryPtr, true /* is triangles */);

    bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);

    de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr(bottomLevelAS.release());
    topLevelAS->setInstanceCount(1);
    // The instance mask is kGoodCullMask, so kBadCullMask (no overlapping bits) culls the only instance.
    // NOTE(review): the 0u arguments presumably are the custom index and the instance SBT record offset; confirm
    // against the addInstance() declaration.
    topLevelAS->addInstance(blasSharedPtr, identityMatrix3x4, 0u, kGoodCullMask, 0u,
                            VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR);
    topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);

    // Input storage buffer.
    const auto inputBufferSize = static_cast<VkDeviceSize>(sizeof(ArgsBufferData));
    const auto inputBufferInfo = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
    BufferWithMemory inputBuffer(vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
    auto &inputBufferAlloc = inputBuffer.getAllocation();

    // Output storage buffer.
    const auto outputBufferSize = static_cast<VkDeviceSize>(sizeof(uint32_t));
    const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
    BufferWithMemory outputBuffer(vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
    auto &outputBufferAlloc = outputBuffer.getAllocation();

    // Fill output buffer with an initial value.
    // Zero never matches a generated id, because every id has a non-zero tag in its upper 16 bits.
    deMemset(outputBufferAlloc.getHostPtr(), 0, static_cast<size_t>(outputBufferSize));
    flushAlloc(vkd, device, outputBufferAlloc);

    // Descriptor set layout and pipeline layout.
    DescriptorSetLayoutBuilder setLayoutBuilder;
    setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, stages);
    setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
    setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
    const auto setLayout      = setLayoutBuilder.build(vkd, device);
    const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());

    // Descriptor pool and set.
    DescriptorPoolBuilder poolBuilder;
    poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
    poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
    const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
    const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());

    // Update descriptor set.
    {
        const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo = {
            VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
            nullptr,
            1u,
            topLevelAS.get()->getPtr(),
        };

        const auto inputBufferDescInfo  = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
        const auto outputBufferDescInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);

        DescriptorSetUpdateBuilder updateBuilder;
        updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u),
                                  VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
        updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u),
                                  VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescInfo);
        updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u),
                                  VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescInfo);
        updateBuilder.update(vkd, device);
    }

    // Shader modules.
    auto rgenModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0));
    auto missModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0));
    auto chitModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0));

    // Get some ray tracing properties.
    uint32_t shaderGroupHandleSize    = 0u;
    uint32_t shaderGroupBaseAlignment = 1u;
    {
        const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
        shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
        shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
    }

    // Create raytracing pipeline and shader binding tables.
    Move<VkPipeline> pipeline;

    de::MovePtr<BufferWithMemory> raygenSBT;
    de::MovePtr<BufferWithMemory> missSBT;
    de::MovePtr<BufferWithMemory> hitSBT;
    de::MovePtr<BufferWithMemory> callableSBT;

    VkStridedDeviceAddressRegionKHR raygenSBTRegion   = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR missSBTRegion     = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR hitSBTRegion      = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR callableSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);

    // Generate ids for the closest hit and miss shaders according to the test parameters.
    DE_ASSERT(m_params.hitParams.rayTypeCount > 0u);
    DE_ASSERT(m_params.hitParams.rayType < m_params.hitParams.rayTypeCount);
    DE_ASSERT(geometries.size() > 0u);

    // One miss shader per index up to and including the requested missIndex, so the requested entry exists.
    std::vector<uint32_t> missShaderIds;
    for (uint32_t missIdx = 0; missIdx <= m_params.missParams.missIndex; ++missIdx)
        missShaderIds.push_back(makeMissId(missIdx));

    // One closest hit shader per (geometry, ray type) pair, in geometry-major order. The id of entry
    // geoIdx * rayTypeCount + rayIdx therefore encodes exactly that SBT hit record index.
    uint32_t chitCounter = 0u;
    std::vector<uint32_t> chitShaderIds;

    for (size_t geoIdx = 0; geoIdx < geometries.size(); ++geoIdx)
        for (uint32_t rayIdx = 0; rayIdx < m_params.hitParams.rayTypeCount; ++rayIdx)
            chitShaderIds.push_back(makeChitId(chitCounter++));

    {
        const auto rayTracingPipeline                        = de::newMovePtr<RayTracingPipeline>();
        const VkSpecializationMapEntry specializationMapEntry = {
            0u,                                       // uint32_t constantID;
            0u,                                       // uint32_t offset;
            static_cast<uintptr_t>(sizeof(uint32_t)), // uintptr_t size;
        };
        VkSpecializationInfo specInfo = {
            1u,                                       // uint32_t mapEntryCount;
            &specializationMapEntry,                  // const VkSpecializationMapEntry* pMapEntries;
            static_cast<uintptr_t>(sizeof(uint32_t)), // uintptr_t dataSize;
            nullptr,                                  // const void* pData;
        };

        // addShader() receives pointers into specInfos (&specInfos.back()). Reserving the final size up front
        // keeps those pointers valid, because push_back() never reallocates.
        std::vector<VkSpecializationInfo> specInfos;
        specInfos.reserve(missShaderIds.size() + chitShaderIds.size());

        // Group layout: [0] raygen, [1 .. missCount] miss shaders, then all closest hit shaders.
        uint32_t shaderGroupIdx = 0u;
        rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, shaderGroupIdx++);

        for (size_t missIdx = 0; missIdx < missShaderIds.size(); ++missIdx)
        {
            specInfo.pData = &missShaderIds.at(missIdx);
            specInfos.push_back(specInfo);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, missModule, shaderGroupIdx++,
                                          &specInfos.back());
        }

        const auto firstChitGroup = shaderGroupIdx;

        for (size_t chitIdx = 0; chitIdx < chitShaderIds.size(); ++chitIdx)
        {
            specInfo.pData = &chitShaderIds.at(chitIdx);
            specInfos.push_back(specInfo);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chitModule, shaderGroupIdx++,
                                          &specInfos.back());
        }

        pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());

        raygenSBT       = rayTracingPipeline->createShaderBindingTable(
            vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0u, 1u);
        raygenSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0),
                                                            shaderGroupHandleSize, shaderGroupHandleSize);

        missSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc,
                                                                     shaderGroupHandleSize, shaderGroupBaseAlignment, 1u,
                                                                     static_cast<uint32_t>(missShaderIds.size()));
        missSBTRegion =
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missSBT->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize * missShaderIds.size());

        hitSBT       = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize,
                                                                    shaderGroupBaseAlignment, firstChitGroup,
                                                                    static_cast<uint32_t>(chitShaderIds.size()));
        hitSBTRegion =
            makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitSBT->get(), 0),
                                              shaderGroupHandleSize, shaderGroupHandleSize * chitShaderIds.size());
    }

    // Fill input buffer values.
    {
        DE_ASSERT(!(m_params.miss && m_params.missParams.missCause == MissCause::NONE));

        // Exactly one argument is made bad for miss cases. rayType and rayTypeCount become the SBT record offset
        // and stride.
        const ArgsBufferData argsBufferData = {
            ((m_params.miss && m_params.missParams.missCause == MissCause::ORIGIN) ? kBadOrigin : kGoodOrigin),
            ((m_params.miss && m_params.missParams.missCause == MissCause::DIRECTION) ? kBadDirection : kGoodDirection),
            ((m_params.miss && m_params.missParams.missCause == MissCause::TMIN) ? kBadTmin : kGoodTmin),
            ((m_params.miss && m_params.missParams.missCause == MissCause::TMAX) ? kBadTmax : kGoodTmax),
            ((m_params.miss && m_params.missParams.missCause == MissCause::FLAGS) ? kBadFlags : kGoodFlags),
            ((m_params.miss && m_params.missParams.missCause == MissCause::CULL_MASK) ? kBadCullMask : kGoodCullMask),
            m_params.hitParams.rayType,
            m_params.hitParams.rayTypeCount,
            m_params.missParams.missIndex,
        };

        deMemcpy(inputBufferAlloc.getHostPtr(), &argsBufferData, sizeof(argsBufferData));
        flushAlloc(vkd, device, inputBufferAlloc);
    }

    // Trace rays.
    vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
    vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
                              &descriptorSet.get(), 0u, nullptr);
    vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &missSBTRegion, &hitSBTRegion, &callableSBTRegion, 1u, 1u, 1u);

    // Barrier for the output buffer.
    // Makes the shader write to the result buffer available to host reads after the submission completes.
    const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
    vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
                           &bufferBarrier, 0u, nullptr, 0u, nullptr);

    endCommandBuffer(vkd, cmdBuffer);
    submitCommandsAndWait(vkd, device, queue, cmdBuffer);

    // Check output value.
    invalidateAlloc(vkd, device, outputBufferAlloc);
    uint32_t outputVal = std::numeric_limits<uint32_t>::max();
    deMemcpy(&outputVal, outputBufferAlloc.getHostPtr(), sizeof(outputVal));
    // A hit lands on geometry 1 (the onscreen triangle). The hit record index is therefore
    // sbtRecordOffset + 1 * sbtRecordStride = rayType + rayTypeCount, assuming the instance SBT offset is 0.
    const auto expectedVal = (m_params.miss ? makeMissId(m_params.missParams.missIndex) :
                                              makeChitId(m_params.hitParams.rayTypeCount + m_params.hitParams.rayType));

    std::ostringstream msg;
    msg << "Output value: 0x" << std::hex << outputVal << " (expected 0x" << expectedVal << ")";

    if (outputVal != expectedVal)
        return tcu::TestStatus::fail(msg.str());

    auto &log = m_context.getTestContext().getLog();
    log << tcu::TestLog::Message << msg.str() << tcu::TestLog::EndMessage;

    return tcu::TestStatus::pass("Pass");
}
511
512 } // namespace
513
createNonUniformArgsTests(tcu::TestContext & testCtx)514 tcu::TestCaseGroup *createNonUniformArgsTests(tcu::TestContext &testCtx)
515 {
516 // Test non-uniform arguments in traceRayExt()
517 de::MovePtr<tcu::TestCaseGroup> nonUniformGroup(new tcu::TestCaseGroup(testCtx, "non_uniform_args"));
518
519 // Closest hit cases.
520 {
521 NonUniformParams params;
522 params.miss = false;
523 params.missParams.missIndex = 0u;
524 params.missParams.missCause = MissCause::NONE;
525
526 for (uint32_t typeCount = 1u; typeCount <= 4u; ++typeCount)
527 {
528 params.hitParams.rayTypeCount = typeCount;
529 for (uint32_t rayType = 0u; rayType < typeCount; ++rayType)
530 {
531 params.hitParams.rayType = rayType;
532 nonUniformGroup->addChild(new NonUniformArgsCase(
533 testCtx, "chit_" + de::toString(typeCount) + "_types_" + de::toString(rayType), params));
534 }
535 }
536 }
537
538 // Miss cases.
539 {
540 NonUniformParams params;
541 params.miss = true;
542 params.hitParams.rayTypeCount = 1u;
543 params.hitParams.rayType = 0u;
544
545 for (int causeIdx = static_cast<int>(MissCause::NONE) + 1; causeIdx < static_cast<int>(MissCause::CAUSE_COUNT);
546 ++causeIdx)
547 {
548 params.missParams.missCause = static_cast<MissCause>(causeIdx);
549 params.missParams.missIndex = static_cast<uint32_t>(causeIdx - 1);
550 nonUniformGroup->addChild(new NonUniformArgsCase(testCtx, "miss_cause_" + de::toString(causeIdx), params));
551 }
552 }
553
554 return nonUniformGroup.release();
555 }
556
557 } // namespace RayTracing
558 } // namespace vkt
559