1 /*-------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2022 The Khronos Group Inc.
6  * Copyright (c) 2022 NVIDIA Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Ray Query Position Fetch Tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktRayQueryPositionFetchTests.hpp"
26 #include "vktTestCase.hpp"
27 
28 #include "vkRayTracingUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkBufferWithMemory.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35 #include "vktTestGroupUtil.hpp"
36 
37 #include "deUniquePtr.hpp"
38 #include "deRandom.hpp"
39 
40 #include "tcuVectorUtil.hpp"
41 
42 #include <sstream>
43 #include <vector>
44 #include <iostream>
45 
46 namespace vkt
47 {
48 namespace RayQuery
49 {
50 
51 namespace
52 {
53 
54 using namespace vk;
55 
// Kind of pipeline a test variant is built around. Stored in
// TestParams::shaderSourcePipeline.
enum ShaderSourcePipeline
{
    SSP_GRAPHICS_PIPELINE    = 0,
    SSP_COMPUTE_PIPELINE     = 1,
    SSP_RAY_TRACING_PIPELINE = 2
};
62 
// Shader stage that issues the ray query. Selects which GLSL source is
// generated and which kind of pipeline is recorded.
enum ShaderSourceType
{
    SST_VERTEX_SHADER         = 0,
    SST_COMPUTE_SHADER        = 1,
    SST_RAY_GENERATION_SHADER = 2,
};
69 
// Bit flags combined into TestParams::testFlagMask.
enum TestFlagBits
{
    // Use a non-identity instance transform in the top-level acceleration structure.
    TEST_FLAG_BIT_INSTANCE_TRANSFORM = 1U << 0,
    // Sentinel: the bit following the last real flag.
    TEST_FLAG_BIT_LAST = TEST_FLAG_BIT_INSTANCE_TRANSFORM << 1,
};
75 
// Human-readable names for the flags in TestFlagBits.
// NOTE(review): presumably indexed by bit position when building test names - confirm against the user of this table.
std::vector<std::string> testFlagBitNames{
    std::string("instance_transform"),
};
79 
// Parameters describing a single position fetch test variant.
struct TestParams
{
    // Shader stage that performs the ray query (selects the generated GLSL and the recorded pipeline).
    ShaderSourceType shaderSourceType;
    // Pipeline kind associated with the variant.
    ShaderSourcePipeline shaderSourcePipeline;
    vk::VkAccelerationStructureBuildTypeKHR buildType; // are we making AS on CPU or GPU
    // Vertex format of the triangle geometry; checked for acceleration structure support.
    VkFormat vertexFormat;
    // Combination of TestFlagBits values.
    uint32_t testFlagMask;
};
88 
// Number of shader invocations launched per test: used as the compute
// local_size_x, the vertex count of the draw, the width of the ray trace
// dispatch and the per-invocation stride of the shader loop.
static constexpr uint32_t kNumThreadsAtOnce = 128u;
90 
91 class PositionFetchCase : public TestCase
92 {
93 public:
94     PositionFetchCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params);
~PositionFetchCase(void)95     virtual ~PositionFetchCase(void)
96     {
97     }
98 
99     virtual void checkSupport(Context &context) const;
100     virtual void initPrograms(vk::SourceCollections &programCollection) const;
101     virtual TestInstance *createInstance(Context &context) const;
102 
103 protected:
104     TestParams m_params;
105 };
106 
107 class PositionFetchInstance : public TestInstance
108 {
109 public:
110     PositionFetchInstance(Context &context, const TestParams &params);
~PositionFetchInstance(void)111     virtual ~PositionFetchInstance(void)
112     {
113     }
114 
115     virtual tcu::TestStatus iterate(void);
116 
117 protected:
118     TestParams m_params;
119 };
120 
PositionFetchCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)121 PositionFetchCase::PositionFetchCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params)
122     : TestCase(testCtx, name)
123     , m_params(params)
124 {
125 }
126 
checkSupport(Context & context) const127 void PositionFetchCase::checkSupport(Context &context) const
128 {
129     context.requireDeviceFunctionality("VK_KHR_ray_query");
130     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
131     context.requireDeviceFunctionality("VK_KHR_ray_tracing_position_fetch");
132 
133     const VkPhysicalDeviceRayQueryFeaturesKHR &rayQueryFeaturesKHR = context.getRayQueryFeatures();
134     if (rayQueryFeaturesKHR.rayQuery == false)
135         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayQueryFeaturesKHR.rayQuery");
136 
137     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
138         context.getAccelerationStructureFeatures();
139     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
140         TCU_THROW(TestError,
141                   "VK_KHR_ray_query requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
142 
143     if (m_params.buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR &&
144         accelerationStructureFeaturesKHR.accelerationStructureHostCommands == false)
145         TCU_THROW(NotSupportedError,
146                   "Requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructureHostCommands");
147 
148     const VkPhysicalDeviceRayTracingPositionFetchFeaturesKHR &rayTracingPositionFetchFeaturesKHR =
149         context.getRayTracingPositionFetchFeatures();
150     if (rayTracingPositionFetchFeaturesKHR.rayTracingPositionFetch == false)
151         TCU_THROW(NotSupportedError, "Requires VkPhysicalDevicePositionFetchFeaturesKHR.rayTracingPositionFetch");
152 
153     // Check supported vertex format.
154     checkAccelerationStructureVertexBufferFormat(context.getInstanceInterface(), context.getPhysicalDevice(),
155                                                  m_params.vertexFormat);
156 
157     if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
158     {
159         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
160 
161         const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
162             context.getRayTracingPipelineFeatures();
163 
164         if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
165             TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
166     }
167 
168     if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER || m_params.shaderSourceType == SST_COMPUTE_SHADER)
169     {
170         const VkPhysicalDeviceLimits &deviceLimits = context.getDeviceProperties().limits;
171         if (kNumThreadsAtOnce > deviceLimits.maxComputeWorkGroupSize[0])
172         {
173             TCU_THROW(NotSupportedError, "Compute workgroup size exceeds device limit");
174         }
175     }
176 
177     switch (m_params.shaderSourceType)
178     {
179     case SST_VERTEX_SHADER:
180         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
181         break;
182     default:
183         break;
184     }
185 }
186 
initPrograms(vk::SourceCollections & programCollection) const187 void PositionFetchCase::initPrograms(vk::SourceCollections &programCollection) const
188 {
189     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
190 
191     uint32_t numRays = 1; // XXX
192 
193     std::ostringstream sharedHeader;
194     sharedHeader << "#version 460 core\n"
195                  << "#extension GL_EXT_ray_query : require\n"
196                  << "#extension GL_EXT_ray_tracing_position_fetch : require\n"
197                  << "\n"
198                  << "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
199                  << "layout(set=0, binding=1, std430) buffer RayOrigins {\n"
200                  << "  vec4 values[" << numRays << "];\n"
201                  << "} origins;\n"
202                  << "layout(set=0, binding=2, std430) buffer OutputPositions {\n"
203                  << "  vec4 values[" << 3 * numRays << "];\n"
204                  << "} modes;\n";
205 
206     std::ostringstream mainLoop;
207     mainLoop
208         << "  while (index < " << numRays
209         << ") {\n"
210         //<< "     for (int i=0; i<3; i++) {\n"
211         //<< "       modes.values[3*index.x+i] = vec4(i, 0.0, 5.0, 1.0);\n"
212         //<< "     }\n"
213         << "    const uint  cullMask  = 0xFF;\n"
214         << "    const vec3  origin    = origins.values[index].xyz;\n"
215         << "    const vec3  direction = vec3(0.0, 0.0, -1.0);\n"
216         << "    const float tMin      = 0.0f;\n"
217         << "    const float tMax      = 2.0f;\n"
218         << "    rayQueryEXT rq;\n"
219         << "    rayQueryInitializeEXT(rq, topLevelAS, gl_RayFlagsNoneEXT, cullMask, origin, tMin, direction, tMax);\n"
220         << "    while (rayQueryProceedEXT(rq)) {\n"
221         << "      if (rayQueryGetIntersectionTypeEXT(rq, false) == gl_RayQueryCandidateIntersectionTriangleEXT) {\n"
222         << "        vec3 outputVal[3];\n"
223         << "        rayQueryGetIntersectionTriangleVertexPositionsEXT(rq, false, outputVal);\n"
224         << "        for (int i=0; i<3; i++) {\n"
225         << "           modes.values[3*index.x+i] = vec4(outputVal[i], 0);\n"
226         //        << "           modes.values[3*index.x+i] = vec4(1.0, 1.0, 1.0, 0);\n"
227         << "        }\n"
228         << "      }\n"
229         << "    }\n"
230         << "    index += " << kNumThreadsAtOnce << ";\n"
231         << "  }\n";
232 
233     if (m_params.shaderSourceType == SST_VERTEX_SHADER)
234     {
235         std::ostringstream vert;
236         vert << sharedHeader.str() << "void main()\n"
237              << "{\n"
238              << "  uint index             = gl_VertexIndex.x;\n"
239              << mainLoop.str() << "}\n";
240 
241         programCollection.glslSources.add("vert") << glu::VertexSource(vert.str()) << buildOptions;
242     }
243     else if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
244     {
245         std::ostringstream rgen;
246         rgen << sharedHeader.str() << "#extension GL_EXT_ray_tracing : require\n"
247              << "void main()\n"
248              << "{\n"
249              << "  uint index             = gl_LaunchIDEXT.x;\n"
250              << mainLoop.str() << "}\n";
251 
252         programCollection.glslSources.add("rgen")
253             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
254     }
255     else
256     {
257         DE_ASSERT(m_params.shaderSourceType == SST_COMPUTE_SHADER);
258         std::ostringstream comp;
259         comp << sharedHeader.str() << "layout(local_size_x=" << kNumThreadsAtOnce
260              << ", local_size_y=1, local_size_z=1) in;\n"
261              << "\n"
262              << "void main()\n"
263              << "{\n"
264              << "  uint index             = gl_LocalInvocationID.x;\n"
265              << mainLoop.str() << "}\n";
266 
267         programCollection.glslSources.add("comp")
268             << glu::ComputeSource(updateRayTracingGLSL(comp.str())) << buildOptions;
269     }
270 }
271 
createInstance(Context & context) const272 TestInstance *PositionFetchCase::createInstance(Context &context) const
273 {
274     return new PositionFetchInstance(context, m_params);
275 }
276 
PositionFetchInstance(Context & context,const TestParams & params)277 PositionFetchInstance::PositionFetchInstance(Context &context, const TestParams &params)
278     : TestInstance(context)
279     , m_params(params)
280 {
281 }
282 
makeEmptyRenderPass(const DeviceInterface & vk,const VkDevice device)283 static Move<VkRenderPass> makeEmptyRenderPass(const DeviceInterface &vk, const VkDevice device)
284 {
285     std::vector<VkSubpassDescription> subpassDescriptions;
286 
287     const VkSubpassDescription description = {
288         (VkSubpassDescriptionFlags)0,    //  VkSubpassDescriptionFlags flags;
289         VK_PIPELINE_BIND_POINT_GRAPHICS, //  VkPipelineBindPoint pipelineBindPoint;
290         0u,                              //  uint32_t inputAttachmentCount;
291         DE_NULL,                         //  const VkAttachmentReference* pInputAttachments;
292         0u,                              //  uint32_t colorAttachmentCount;
293         DE_NULL,                         //  const VkAttachmentReference* pColorAttachments;
294         DE_NULL,                         //  const VkAttachmentReference* pResolveAttachments;
295         DE_NULL,                         //  const VkAttachmentReference* pDepthStencilAttachment;
296         0,                               //  uint32_t preserveAttachmentCount;
297         DE_NULL                          //  const uint32_t* pPreserveAttachments;
298     };
299     subpassDescriptions.push_back(description);
300 
301     const VkRenderPassCreateInfo renderPassInfo = {
302         VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO,         //  VkStructureType sType;
303         DE_NULL,                                           //  const void* pNext;
304         static_cast<VkRenderPassCreateFlags>(0u),          //  VkRenderPassCreateFlags flags;
305         0u,                                                //  uint32_t attachmentCount;
306         DE_NULL,                                           //  const VkAttachmentDescription* pAttachments;
307         static_cast<uint32_t>(subpassDescriptions.size()), //  uint32_t subpassCount;
308         &subpassDescriptions[0],                           //  const VkSubpassDescription* pSubpasses;
309         0u,                                                //  uint32_t dependencyCount;
310         DE_NULL                                            //  const VkSubpassDependency* pDependencies;
311     };
312 
313     return createRenderPass(vk, device, &renderPassInfo);
314 }
315 
makeFramebuffer(const DeviceInterface & vk,const VkDevice device,VkRenderPass renderPass,uint32_t width,uint32_t height)316 static Move<VkFramebuffer> makeFramebuffer(const DeviceInterface &vk, const VkDevice device, VkRenderPass renderPass,
317                                            uint32_t width, uint32_t height)
318 {
319     const vk::VkFramebufferCreateInfo framebufferParams = {
320         vk::VK_STRUCTURE_TYPE_FRAMEBUFFER_CREATE_INFO, // sType
321         DE_NULL,                                       // pNext
322         (vk::VkFramebufferCreateFlags)0,
323         renderPass, // renderPass
324         0u,         // attachmentCount
325         DE_NULL,    // pAttachments
326         width,      // width
327         height,     // height
328         1u,         // layers
329     };
330 
331     return createFramebuffer(vk, device, &framebufferParams);
332 }
333 
makeGraphicsPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const VkRenderPass renderPass,const VkShaderModule vertexModule,const uint32_t subpass)334 Move<VkPipeline> makeGraphicsPipeline(const DeviceInterface &vk, const VkDevice device,
335                                       const VkPipelineLayout pipelineLayout, const VkRenderPass renderPass,
336                                       const VkShaderModule vertexModule, const uint32_t subpass)
337 {
338     VkExtent2D renderSize{256, 256};
339     VkViewport viewport = makeViewport(renderSize);
340     VkRect2D scissor    = makeRect2D(renderSize);
341 
342     const VkPipelineViewportStateCreateInfo viewportStateCreateInfo = {
343         VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, // VkStructureType                             sType
344         DE_NULL,                                               // const void*                                 pNext
345         (VkPipelineViewportStateCreateFlags)0,                 // VkPipelineViewportStateCreateFlags          flags
346         1u,        // uint32_t                                    viewportCount
347         &viewport, // const VkViewport*                           pViewports
348         1u,        // uint32_t                                    scissorCount
349         &scissor   // const VkRect2D*                             pScissors
350     };
351 
352     const VkPipelineInputAssemblyStateCreateInfo inputAssemblyStateCreateInfo = {
353         VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO, // VkStructureType                            sType
354         DE_NULL,                                                     // const void*                                pNext
355         0u,                                                          // VkPipelineInputAssemblyStateCreateFlags    flags
356         VK_PRIMITIVE_TOPOLOGY_POINT_LIST, // VkPrimitiveTopology                        topology
357         VK_FALSE                          // VkBool32                                   primitiveRestartEnable
358     };
359 
360     const VkPipelineVertexInputStateCreateInfo vertexInputStateCreateInfo = {
361         VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, //  VkStructureType                                    sType
362         DE_NULL,                                  //  const void*                                        pNext
363         (VkPipelineVertexInputStateCreateFlags)0, //  VkPipelineVertexInputStateCreateFlags            flags
364         0u,      //  uint32_t                                        vertexBindingDescriptionCount
365         DE_NULL, //  const VkVertexInputBindingDescription*            pVertexBindingDescriptions
366         0u,      //  uint32_t                                        vertexAttributeDescriptionCount
367         DE_NULL, //  const VkVertexInputAttributeDescription*        pVertexAttributeDescriptions
368     };
369 
370     const VkPipelineRasterizationStateCreateInfo rasterizationStateCreateInfo = {
371         VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO, //  VkStructureType                            sType
372         DE_NULL,                                                    //  const void*                                pNext
373         0u,                                                         //  VkPipelineRasterizationStateCreateFlags    flags
374         VK_FALSE,                        //  VkBool32                                depthClampEnable
375         VK_TRUE,                         //  VkBool32                                rasterizerDiscardEnable
376         VK_POLYGON_MODE_FILL,            //  VkPolygonMode                            polygonMode
377         VK_CULL_MODE_NONE,               //  VkCullModeFlags                            cullMode
378         VK_FRONT_FACE_COUNTER_CLOCKWISE, //  VkFrontFace                                frontFace
379         VK_FALSE,                        //  VkBool32                                depthBiasEnable
380         0.0f,                            //  float                                    depthBiasConstantFactor
381         0.0f,                            //  float                                    depthBiasClamp
382         0.0f,                            //  float                                    depthBiasSlopeFactor
383         1.0f                             //  float                                    lineWidth
384     };
385 
386     return makeGraphicsPipeline(
387         vk,                             // const DeviceInterface&                            vk
388         device,                         // const VkDevice                                    device
389         pipelineLayout,                 // const VkPipelineLayout                            pipelineLayout
390         vertexModule,                   // const VkShaderModule                                vertexShaderModule
391         DE_NULL,                        // const VkShaderModule                                tessellationControlModule
392         DE_NULL,                        // const VkShaderModule                                tessellationEvalModule
393         DE_NULL,                        // const VkShaderModule                                geometryShaderModule
394         DE_NULL,                        // const VkShaderModule                                fragmentShaderModule
395         renderPass,                     // const VkRenderPass                                renderPass
396         subpass,                        // const uint32_t                                    subpass
397         &vertexInputStateCreateInfo,    // const VkPipelineVertexInputStateCreateInfo*        vertexInputStateCreateInfo
398         &inputAssemblyStateCreateInfo,  // const VkPipelineInputAssemblyStateCreateInfo*    inputAssemblyStateCreateInfo
399         DE_NULL,                        // const VkPipelineTessellationStateCreateInfo*        tessStateCreateInfo
400         &viewportStateCreateInfo,       // const VkPipelineViewportStateCreateInfo*            viewportStateCreateInfo
401         &rasterizationStateCreateInfo); // const VkPipelineRasterizationStateCreateInfo*    rasterizationStateCreateInfo
402 }
403 
iterate(void)404 tcu::TestStatus PositionFetchInstance::iterate(void)
405 {
406     const auto &vkd   = m_context.getDeviceInterface();
407     const auto device = m_context.getDevice();
408     auto &alloc       = m_context.getDefaultAllocator();
409     const auto qIndex = m_context.getUniversalQueueFamilyIndex();
410     const auto queue  = m_context.getUniversalQueue();
411 
412     // Command pool and buffer.
413     const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
414     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
415     const auto cmdBuffer    = cmdBufferPtr.get();
416 
417     beginCommandBuffer(vkd, cmdBuffer);
418 
419     // Build acceleration structures.
420     auto topLevelAS    = makeTopLevelAccelerationStructure();
421     auto bottomLevelAS = makeBottomLevelAccelerationStructure();
422 
423     const std::vector<tcu::Vec3> triangle = {
424         tcu::Vec3(0.0f, 0.0f, 0.0f),
425         tcu::Vec3(1.0f, 0.0f, 0.0f),
426         tcu::Vec3(0.0f, 1.0f, 0.0f),
427     };
428 
429     const VkTransformMatrixKHR notQuiteIdentityMatrix3x4 = {
430         {{0.98f, 0.0f, 0.0f, 0.0f}, {0.0f, 0.97f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.99f, 0.0f}}};
431 
432     de::SharedPtr<RaytracedGeometryBase> geometry =
433         makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, m_params.vertexFormat, VK_INDEX_TYPE_NONE_KHR);
434 
435     for (auto &v : triangle)
436     {
437         geometry->addVertex(v);
438     }
439 
440     bottomLevelAS->addGeometry(geometry);
441     bottomLevelAS->setBuildFlags(VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_DATA_ACCESS_KHR);
442     bottomLevelAS->setBuildType(m_params.buildType);
443     bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
444     de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr(bottomLevelAS.release());
445 
446     topLevelAS->setInstanceCount(1);
447     topLevelAS->setBuildType(m_params.buildType);
448     topLevelAS->addInstance(blasSharedPtr, (m_params.testFlagMask & TEST_FLAG_BIT_INSTANCE_TRANSFORM) ?
449                                                notQuiteIdentityMatrix3x4 :
450                                                identityMatrix3x4);
451     topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
452 
453     // One ray for this test
454     // XXX Should it be multiple triangles and one ray per triangle for more coverage?
455     // XXX If it's really one ray, the origin buffer is complete overkill
456     uint32_t numRays = 1; // XXX
457 
458     // SSBO buffer for origins.
459     const auto originsBufferSize = static_cast<VkDeviceSize>(sizeof(tcu::Vec4) * numRays);
460     const auto originsBufferInfo = makeBufferCreateInfo(originsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
461     BufferWithMemory originsBuffer(vkd, device, alloc, originsBufferInfo, MemoryRequirement::HostVisible);
462     auto &originsBufferAlloc = originsBuffer.getAllocation();
463     void *originsBufferData  = originsBufferAlloc.getHostPtr();
464 
465     std::vector<tcu::Vec4> origins;
466     std::vector<tcu::Vec3> expectedOutputPositions;
467     origins.reserve(numRays);
468     expectedOutputPositions.reserve(3 * numRays);
469 
470     // Fill in vector of expected outputs
471     for (uint32_t index = 0; index < numRays; index++)
472     {
473         for (uint32_t vert = 0; vert < 3; vert++)
474         {
475             tcu::Vec3 pos = triangle[vert];
476 
477             expectedOutputPositions.push_back(pos);
478         }
479     }
480 
481     // XXX Arbitrary location and see above
482     for (uint32_t index = 0; index < numRays; index++)
483     {
484         origins.push_back(tcu::Vec4(0.25, 0.25, 1.0, 0.0));
485     }
486 
487     const auto originsBufferSizeSz = static_cast<size_t>(originsBufferSize);
488     deMemcpy(originsBufferData, origins.data(), originsBufferSizeSz);
489     flushAlloc(vkd, device, originsBufferAlloc);
490 
491     // Storage buffer for output modes
492     const auto outputPositionsBufferSize = static_cast<VkDeviceSize>(3 * 4 * sizeof(float) * numRays);
493     const auto outputPositionsBufferInfo =
494         makeBufferCreateInfo(outputPositionsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
495     BufferWithMemory outputPositionsBuffer(vkd, device, alloc, outputPositionsBufferInfo,
496                                            MemoryRequirement::HostVisible);
497     auto &outputPositionsBufferAlloc = outputPositionsBuffer.getAllocation();
498     void *outputPositionsBufferData  = outputPositionsBufferAlloc.getHostPtr();
499     deMemset(outputPositionsBufferData, 0xFF, static_cast<size_t>(outputPositionsBufferSize));
500     flushAlloc(vkd, device, outputPositionsBufferAlloc);
501 
502     // Descriptor set layout.
503     DescriptorSetLayoutBuilder dsLayoutBuilder;
504     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, VK_SHADER_STAGE_ALL);
505     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_SHADER_STAGE_ALL);
506     dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_SHADER_STAGE_ALL);
507     const auto setLayout = dsLayoutBuilder.build(vkd, device);
508 
509     // Pipeline layout.
510     const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
511 
512     // Descriptor pool and set.
513     DescriptorPoolBuilder poolBuilder;
514     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
515     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
516     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
517     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
518     const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
519 
520     // Update descriptor set.
521     {
522         const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo = {
523             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
524             nullptr,
525             1u,
526             topLevelAS.get()->getPtr(),
527         };
528         const auto inStorageBufferInfo = makeDescriptorBufferInfo(originsBuffer.get(), 0ull, VK_WHOLE_SIZE);
529         const auto storageBufferInfo   = makeDescriptorBufferInfo(outputPositionsBuffer.get(), 0ull, VK_WHOLE_SIZE);
530 
531         DescriptorSetUpdateBuilder updateBuilder;
532         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u),
533                                   VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
534         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u),
535                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inStorageBufferInfo);
536         updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u),
537                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferInfo);
538         updateBuilder.update(vkd, device);
539     }
540 
541     Move<VkPipeline> pipeline;
542     de::MovePtr<BufferWithMemory> raygenSBT;
543     Move<VkRenderPass> renderPass;
544     Move<VkFramebuffer> framebuffer;
545 
546     if (m_params.shaderSourceType == SST_VERTEX_SHADER)
547     {
548         auto vertexModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("vert"), 0);
549 
550         const uint32_t width  = 32u;
551         const uint32_t height = 32u;
552         renderPass            = makeEmptyRenderPass(vkd, device);
553         framebuffer           = makeFramebuffer(vkd, device, *renderPass, width, height);
554         pipeline              = makeGraphicsPipeline(vkd, device, *pipelineLayout, *renderPass, *vertexModule, 0);
555 
556         const VkRenderPassBeginInfo renderPassBeginInfo = {
557             VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO, // VkStructureType sType;
558             DE_NULL,                                  // const void* pNext;
559             *renderPass,                              // VkRenderPass renderPass;
560             *framebuffer,                             // VkFramebuffer framebuffer;
561             makeRect2D(width, height),                // VkRect2D renderArea;
562             0u,                                       // uint32_t clearValueCount;
563             DE_NULL                                   // const VkClearValue* pClearValues;
564         };
565 
566         vkd.cmdBeginRenderPass(cmdBuffer, &renderPassBeginInfo, VK_SUBPASS_CONTENTS_INLINE);
567         vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline.get());
568         vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout.get(), 0u, 1u,
569                                   &descriptorSet.get(), 0u, nullptr);
570         vkd.cmdDraw(cmdBuffer, kNumThreadsAtOnce, 1, 0, 0);
571         vkd.cmdEndRenderPass(cmdBuffer);
572     }
573     else if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
574     {
575         const auto &vki    = m_context.getInstanceInterface();
576         const auto physDev = m_context.getPhysicalDevice();
577 
578         // Shader module.
579         auto rgenModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0);
580 
581         // Get some ray tracing properties.
582         uint32_t shaderGroupHandleSize    = 0u;
583         uint32_t shaderGroupBaseAlignment = 1u;
584         {
585             const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
586             shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
587             shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
588         }
589 
590         auto raygenSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
591         auto unusedSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
592 
593         {
594             const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
595             rayTracingPipeline->setCreateFlags(VK_PIPELINE_CREATE_RAY_TRACING_OPACITY_MICROMAP_BIT_EXT);
596             rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, 0);
597 
598             pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
599 
600             raygenSBT = rayTracingPipeline->createShaderBindingTable(
601                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
602             raygenSBTRegion = makeStridedDeviceAddressRegionKHR(
603                 getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
604         }
605 
606         // Trace rays.
607         vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
608         vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
609                                   &descriptorSet.get(), 0u, nullptr);
610         vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &unusedSBTRegion, &unusedSBTRegion, &unusedSBTRegion,
611                             kNumThreadsAtOnce, 1u, 1u);
612     }
613     else
614     {
615         DE_ASSERT(m_params.shaderSourceType == SST_COMPUTE_SHADER);
616         // Shader module.
617         const auto compModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0);
618 
619         // Pipeline.
620         const VkPipelineShaderStageCreateInfo shaderInfo = {
621             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
622             nullptr,                                             // const void* pNext;
623             0u,                                                  // VkPipelineShaderStageCreateFlags flags;
624             VK_SHADER_STAGE_COMPUTE_BIT,                         // VkShaderStageFlagBits stage;
625             compModule.get(),                                    // VkShaderModule module;
626             "main",                                              // const char* pName;
627             nullptr,                                             // const VkSpecializationInfo* pSpecializationInfo;
628         };
629         const VkComputePipelineCreateInfo pipelineInfo = {
630             VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // VkStructureType sType;
631             nullptr,                                        // const void* pNext;
632             0u,                                             // VkPipelineCreateFlags flags;
633             shaderInfo,                                     // VkPipelineShaderStageCreateInfo stage;
634             pipelineLayout.get(),                           // VkPipelineLayout layout;
635             DE_NULL,                                        // VkPipeline basePipelineHandle;
636             0,                                              // int32_t basePipelineIndex;
637         };
638         pipeline = createComputePipeline(vkd, device, DE_NULL, &pipelineInfo);
639 
640         // Dispatch work with ray queries.
641         vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
642         vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0u, 1u,
643                                   &descriptorSet.get(), 0u, nullptr);
644         vkd.cmdDispatch(cmdBuffer, 1u, 1u, 1u);
645     }
646 
647     // Barrier for the output buffer.
648     const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
649     vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
650                            &bufferBarrier, 0u, nullptr, 0u, nullptr);
651 
652     endCommandBuffer(vkd, cmdBuffer);
653     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
654 
655     // Verify results.
656     std::vector<tcu::Vec4> outputData(expectedOutputPositions.size());
657     const auto outputPositionsBufferSizeSz = static_cast<size_t>(outputPositionsBufferSize);
658 
659     invalidateAlloc(vkd, device, outputPositionsBufferAlloc);
660     DE_ASSERT(de::dataSize(outputData) == outputPositionsBufferSizeSz);
661     deMemcpy(outputData.data(), outputPositionsBufferData, outputPositionsBufferSizeSz);
662 
663     for (size_t i = 0; i < outputData.size(); ++i)
664     {
665         /*const */ auto &outVal = outputData[i]; // Should be const but .xyz() isn't
666         tcu::Vec3 outVec3       = outVal.xyz();
667         const auto &expectedVal = expectedOutputPositions[i];
668         const auto &diff        = expectedOutputPositions[i] - outVec3;
669         float len               = dot(diff, diff);
670 
671         // XXX Find a better epsilon
672         if (!(len < 1e-5))
673         {
674             std::ostringstream msg;
675             msg << "Unexpected value found for element " << i << ": expected " << expectedVal << " and found " << outVal
676                 << ";";
677             TCU_FAIL(msg.str());
678         }
679 #if 0
680         else
681         {
682             std::ostringstream msg;
683             msg << "Expected value found for element " << i << ": expected " << expectedVal << " and found " << outVal << ";\n";
684             std::cout << msg.str();
685         }
686 #endif
687     }
688 
689     return tcu::TestStatus::pass("Pass");
690 }
691 
692 } // namespace
693 
createPositionFetchTests(tcu::TestContext & testCtx)694 tcu::TestCaseGroup *createPositionFetchTests(tcu::TestContext &testCtx)
695 {
696     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "position_fetch"));
697 
698     struct
699     {
700         vk::VkAccelerationStructureBuildTypeKHR buildType;
701         const char *name;
702     } buildTypes[] = {
703         {VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR, "cpu_built"},
704         {VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR, "gpu_built"},
705     };
706 
707     const struct
708     {
709         ShaderSourceType shaderSourceType;
710         ShaderSourcePipeline shaderSourcePipeline;
711         std::string name;
712     } shaderSourceTypes[] = {
713         {SST_VERTEX_SHADER, SSP_GRAPHICS_PIPELINE, "vertex_shader"},
714         {
715             SST_COMPUTE_SHADER,
716             SSP_COMPUTE_PIPELINE,
717             "compute_shader",
718         },
719         {
720             SST_RAY_GENERATION_SHADER,
721             SSP_RAY_TRACING_PIPELINE,
722             "rgen_shader",
723         },
724     };
725 
726     const VkFormat vertexFormats[] = {
727         // Mandatory formats.
728         VK_FORMAT_R32G32_SFLOAT,
729         VK_FORMAT_R32G32B32_SFLOAT,
730         VK_FORMAT_R16G16_SFLOAT,
731         VK_FORMAT_R16G16B16A16_SFLOAT,
732         VK_FORMAT_R16G16_SNORM,
733         VK_FORMAT_R16G16B16A16_SNORM,
734 
735         // Additional formats.
736         VK_FORMAT_R8G8_SNORM,
737         VK_FORMAT_R8G8B8_SNORM,
738         VK_FORMAT_R8G8B8A8_SNORM,
739         VK_FORMAT_R16G16B16_SNORM,
740         VK_FORMAT_R16G16B16_SFLOAT,
741         VK_FORMAT_R32G32B32A32_SFLOAT,
742         VK_FORMAT_R64G64_SFLOAT,
743         VK_FORMAT_R64G64B64_SFLOAT,
744         VK_FORMAT_R64G64B64A64_SFLOAT,
745     };
746 
747     for (size_t shaderSourceNdx = 0; shaderSourceNdx < DE_LENGTH_OF_ARRAY(shaderSourceTypes); ++shaderSourceNdx)
748     {
749         de::MovePtr<tcu::TestCaseGroup> sourceTypeGroup(
750             new tcu::TestCaseGroup(group->getTestContext(), shaderSourceTypes[shaderSourceNdx].name.c_str()));
751 
752         for (size_t buildTypeNdx = 0; buildTypeNdx < DE_LENGTH_OF_ARRAY(buildTypes); ++buildTypeNdx)
753         {
754             de::MovePtr<tcu::TestCaseGroup> buildGroup(
755                 new tcu::TestCaseGroup(group->getTestContext(), buildTypes[buildTypeNdx].name));
756 
757             for (size_t vertexFormatNdx = 0; vertexFormatNdx < DE_LENGTH_OF_ARRAY(vertexFormats); ++vertexFormatNdx)
758             {
759                 const auto format     = vertexFormats[vertexFormatNdx];
760                 const auto formatName = getFormatSimpleName(format);
761 
762                 de::MovePtr<tcu::TestCaseGroup> vertexFormatGroup(
763                     new tcu::TestCaseGroup(group->getTestContext(), formatName.c_str()));
764 
765                 for (uint32_t testFlagMask = 0; testFlagMask < TEST_FLAG_BIT_LAST; testFlagMask++)
766                 {
767                     std::string maskName = "";
768 
769                     for (uint32_t bit = 0; bit < testFlagBitNames.size(); bit++)
770                     {
771                         if (testFlagMask & (1 << bit))
772                         {
773                             if (maskName != "")
774                                 maskName += "_";
775                             maskName += testFlagBitNames[bit];
776                         }
777                     }
778                     if (maskName == "")
779                         maskName = "NoFlags";
780 
781                     de::MovePtr<tcu::TestCaseGroup> testFlagGroup(
782                         new tcu::TestCaseGroup(group->getTestContext(), maskName.c_str()));
783 
784                     TestParams testParams{
785                         shaderSourceTypes[shaderSourceNdx].shaderSourceType,
786                         shaderSourceTypes[shaderSourceNdx].shaderSourcePipeline,
787                         buildTypes[buildTypeNdx].buildType,
788                         format,
789                         testFlagMask,
790                     };
791 
792                     vertexFormatGroup->addChild(new PositionFetchCase(testCtx, maskName, testParams));
793                 }
794                 buildGroup->addChild(vertexFormatGroup.release());
795             }
796             sourceTypeGroup->addChild(buildGroup.release());
797         }
798         group->addChild(sourceTypeGroup.release());
799     }
800 
801     return group.release();
802 }
803 } // namespace RayQuery
804 } // namespace vkt
805