1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Callable Shader tests
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayTracingCallableShadersTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vktTestGroupUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkObjUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkBufferWithMemory.hpp"
35 #include "vkImageWithMemory.hpp"
36 #include "vkTypeUtil.hpp"
37 #include "vkImageUtil.hpp"
38 #include "deRandom.hpp"
39 #include "tcuSurface.hpp"
40 #include "tcuTexture.hpp"
41 #include "tcuTextureUtil.hpp"
42 #include "tcuTestLog.hpp"
43 #include "tcuImageCompare.hpp"
44 
45 #include "vkRayTracingUtil.hpp"
46 
47 namespace vkt
48 {
49 namespace RayTracing
50 {
51 namespace
52 {
53 using namespace vk;
54 using namespace vkt;
55 
56 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
57                                               VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
58                                               VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
59 
// Selects which shader set, shader binding table layout and expected result a test uses.
enum CallableShaderTestType
{
    // Pipeline built from the "rgen_call", "chit", "miss" and "call_0" shaders.
    CSTT_RGEN_CALL = 0,
    // Pipeline built from "rgen_call", "chit", "miss", "call_call" and "call_0".
    CSTT_RGEN_CALL_CALL,
    // Pipeline built from "rgen", "chit_call", "miss_call" and "call_0".
    CSTT_HIT_CALL,
    // Pipeline built from "rgen_multicall", "chit", "miss" and "call_0" .. "call_3".
    CSTT_RGEN_MULTICALL,
    // Number of test types; not a valid test type itself.
    CSTT_COUNT
};
68 
// Default test extent. Presumably fed into TestParams::width / TestParams::height by the
// test-creation code, which is not part of this excerpt.
const uint32_t TEST_WIDTH = 8u;
const uint32_t TEST_HEIGHT = 8u;

// Forward declaration: TestConfiguration refers to TestParams before its definition.
struct TestParams;
73 
74 class TestConfiguration
75 {
76 public:
~TestConfiguration()77     virtual ~TestConfiguration()
78     {
79     }
80 
81     virtual std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures(
82         Context &context, TestParams &testParams) = 0;
83     virtual de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure(
84         Context &context, TestParams &testParams,
85         std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures)    = 0;
86     virtual void initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
87                                        TestParams &testParams)                                              = 0;
88     virtual void initShaderBindingTables(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
89                                          TestParams &testParams, VkPipeline pipeline, uint32_t shaderGroupHandleSize,
90                                          uint32_t shaderGroupBaseAlignment,
91                                          de::MovePtr<BufferWithMemory> &raygenShaderBindingTable,
92                                          de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
93                                          de::MovePtr<BufferWithMemory> &missShaderBindingTable,
94                                          de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
95                                          VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
96                                          VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
97                                          VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
98                                          VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion) = 0;
99     virtual bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)      = 0;
100     virtual VkFormat getResultImageFormat()                                                                 = 0;
101     virtual size_t getResultImageFormatSize()                                                               = 0;
102     virtual VkClearValue getClearValue()                                                                    = 0;
103 };
104 
105 struct TestParams
106 {
107     uint32_t width;
108     uint32_t height;
109     CallableShaderTestType callableShaderTestType;
110     de::SharedPtr<TestConfiguration> testConfiguration;
111     glu::ShaderType invokingShader;
112     bool multipleInvocations;
113 };
114 
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)115 uint32_t getShaderGroupHandleSize(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
116 {
117     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
118 
119     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
120     return rayTracingPropertiesKHR->getShaderGroupHandleSize();
121 }
122 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)123 uint32_t getShaderGroupBaseAlignment(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
124 {
125     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
126 
127     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
128     return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
129 }
130 
makeImageCreateInfo(uint32_t width,uint32_t height,VkFormat format)131 VkImageCreateInfo makeImageCreateInfo(uint32_t width, uint32_t height, VkFormat format)
132 {
133     const VkImageCreateInfo imageCreateInfo = {
134         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
135         DE_NULL,                             // const void* pNext;
136         (VkImageCreateFlags)0u,              // VkImageCreateFlags flags;
137         VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
138         format,                              // VkFormat format;
139         makeExtent3D(width, height, 1),      // VkExtent3D extent;
140         1u,                                  // uint32_t mipLevels;
141         1u,                                  // uint32_t arrayLayers;
142         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
143         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
144         VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
145             VK_IMAGE_USAGE_TRANSFER_DST_BIT, // VkImageUsageFlags usage;
146         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
147         0u,                                  // uint32_t queueFamilyIndexCount;
148         DE_NULL,                             // const uint32_t* pQueueFamilyIndices;
149         VK_IMAGE_LAYOUT_UNDEFINED            // VkImageLayout initialLayout;
150     };
151 
152     return imageCreateInfo;
153 }
154 
155 class SingleSquareConfiguration : public TestConfiguration
156 {
157 public:
158     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures(
159         Context &context, TestParams &testParams) override;
160     de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure(
161         Context &context, TestParams &testParams,
162         std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures) override;
163     void initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
164                                TestParams &testParams) override;
165     void initShaderBindingTables(de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context,
166                                  TestParams &testParams, VkPipeline pipeline, uint32_t shaderGroupHandleSize,
167                                  uint32_t shaderGroupBaseAlignment,
168                                  de::MovePtr<BufferWithMemory> &raygenShaderBindingTable,
169                                  de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
170                                  de::MovePtr<BufferWithMemory> &missShaderBindingTable,
171                                  de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
172                                  VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
173                                  VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
174                                  VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
175                                  VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion) override;
176     bool verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams) override;
177     VkFormat getResultImageFormat() override;
178     size_t getResultImageFormatSize() override;
179     VkClearValue getClearValue() override;
180 };
181 
182 std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> SingleSquareConfiguration::
initBottomAccelerationStructures(Context & context,TestParams & testParams)183     initBottomAccelerationStructures(Context &context, TestParams &testParams)
184 {
185     DE_UNREF(context);
186 
187     tcu::Vec3 v0(1.0, float(testParams.height) - 1.0f, 0.0);
188     tcu::Vec3 v1(1.0, 1.0, 0.0);
189     tcu::Vec3 v2(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0);
190     tcu::Vec3 v3(float(testParams.width) - 1.0f, 1.0, 0.0);
191 
192     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> result;
193     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
194         makeBottomLevelAccelerationStructure();
195     bottomLevelAccelerationStructure->setGeometryCount(1);
196 
197     de::SharedPtr<RaytracedGeometryBase> geometry =
198         makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
199     geometry->addVertex(v0);
200     geometry->addVertex(v1);
201     geometry->addVertex(v2);
202     geometry->addVertex(v2);
203     geometry->addVertex(v1);
204     geometry->addVertex(v3);
205     bottomLevelAccelerationStructure->addGeometry(geometry);
206 
207     result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
208 
209     return result;
210 }
211 
initTopAccelerationStructure(Context & context,TestParams & testParams,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)212 de::MovePtr<TopLevelAccelerationStructure> SingleSquareConfiguration::initTopAccelerationStructure(
213     Context &context, TestParams &testParams,
214     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelAccelerationStructures)
215 {
216     DE_UNREF(context);
217     DE_UNREF(testParams);
218 
219     de::MovePtr<TopLevelAccelerationStructure> result = makeTopLevelAccelerationStructure();
220     result->setInstanceCount(1);
221     result->addInstance(bottomLevelAccelerationStructures[0]);
222 
223     return result;
224 }
225 
initRayTracingShaders(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams)226 void SingleSquareConfiguration::initRayTracingShaders(de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
227                                                       Context &context, TestParams &testParams)
228 {
229     const DeviceInterface &vkd = context.getDeviceInterface();
230     const VkDevice device      = context.getDevice();
231 
232     switch (testParams.callableShaderTestType)
233     {
234     case CSTT_RGEN_CALL:
235     {
236         rayTracingPipeline->addShader(
237             VK_SHADER_STAGE_RAYGEN_BIT_KHR,
238             createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
239         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
240                                       createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
241         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
242                                       createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
243         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
244                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0),
245                                       3);
246         break;
247     }
248     case CSTT_RGEN_CALL_CALL:
249     {
250         rayTracingPipeline->addShader(
251             VK_SHADER_STAGE_RAYGEN_BIT_KHR,
252             createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
253         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
254                                       createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
255         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
256                                       createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
257         rayTracingPipeline->addShader(
258             VK_SHADER_STAGE_CALLABLE_BIT_KHR,
259             createShaderModule(vkd, device, context.getBinaryCollection().get("call_call"), 0), 3);
260         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
261                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0),
262                                       4);
263         break;
264     }
265     case CSTT_HIT_CALL:
266     {
267         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
268                                       createShaderModule(vkd, device, context.getBinaryCollection().get("rgen"), 0), 0);
269         rayTracingPipeline->addShader(
270             VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
271             createShaderModule(vkd, device, context.getBinaryCollection().get("chit_call"), 0), 1);
272         rayTracingPipeline->addShader(
273             VK_SHADER_STAGE_MISS_BIT_KHR,
274             createShaderModule(vkd, device, context.getBinaryCollection().get("miss_call"), 0), 2);
275         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
276                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0),
277                                       3);
278         break;
279     }
280     case CSTT_RGEN_MULTICALL:
281     {
282         rayTracingPipeline->addShader(
283             VK_SHADER_STAGE_RAYGEN_BIT_KHR,
284             createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_multicall"), 0), 0);
285         rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
286                                       createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
287         rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
288                                       createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
289         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
290                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0),
291                                       3);
292         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
293                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_1"), 0),
294                                       4);
295         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
296                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_2"), 0),
297                                       5);
298         rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
299                                       createShaderModule(vkd, device, context.getBinaryCollection().get("call_3"), 0),
300                                       6);
301         break;
302     }
303     default:
304         TCU_THROW(InternalError, "Wrong shader test type");
305     }
306 }
307 
initShaderBindingTables(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams,VkPipeline pipeline,uint32_t shaderGroupHandleSize,uint32_t shaderGroupBaseAlignment,de::MovePtr<BufferWithMemory> & raygenShaderBindingTable,de::MovePtr<BufferWithMemory> & hitShaderBindingTable,de::MovePtr<BufferWithMemory> & missShaderBindingTable,de::MovePtr<BufferWithMemory> & callableShaderBindingTable,VkStridedDeviceAddressRegionKHR & raygenShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & hitShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & missShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & callableShaderBindingTableRegion)308 void SingleSquareConfiguration::initShaderBindingTables(
309     de::MovePtr<RayTracingPipeline> &rayTracingPipeline, Context &context, TestParams &testParams, VkPipeline pipeline,
310     uint32_t shaderGroupHandleSize, uint32_t shaderGroupBaseAlignment,
311     de::MovePtr<BufferWithMemory> &raygenShaderBindingTable, de::MovePtr<BufferWithMemory> &hitShaderBindingTable,
312     de::MovePtr<BufferWithMemory> &missShaderBindingTable, de::MovePtr<BufferWithMemory> &callableShaderBindingTable,
313     VkStridedDeviceAddressRegionKHR &raygenShaderBindingTableRegion,
314     VkStridedDeviceAddressRegionKHR &hitShaderBindingTableRegion,
315     VkStridedDeviceAddressRegionKHR &missShaderBindingTableRegion,
316     VkStridedDeviceAddressRegionKHR &callableShaderBindingTableRegion)
317 {
318     const DeviceInterface &vkd = context.getDeviceInterface();
319     const VkDevice device      = context.getDevice();
320     Allocator &allocator       = context.getDefaultAllocator();
321 
322     switch (testParams.callableShaderTestType)
323     {
324     case CSTT_RGEN_CALL:
325     {
326         raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
327             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
328         hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
329             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
330         missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
331             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
332         callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
333             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
334 
335         raygenShaderBindingTableRegion =
336             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
337                                               shaderGroupHandleSize, shaderGroupHandleSize);
338         hitShaderBindingTableRegion =
339             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
340                                               shaderGroupHandleSize, shaderGroupHandleSize);
341         missShaderBindingTableRegion =
342             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
343                                               shaderGroupHandleSize, shaderGroupHandleSize);
344         callableShaderBindingTableRegion =
345             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0),
346                                               shaderGroupHandleSize, shaderGroupHandleSize);
347         break;
348     }
349     case CSTT_RGEN_CALL_CALL:
350     {
351         raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
352             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
353         hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
354             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
355         missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
356             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
357         callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
358             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 2);
359 
360         raygenShaderBindingTableRegion =
361             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
362                                               shaderGroupHandleSize, shaderGroupHandleSize);
363         hitShaderBindingTableRegion =
364             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
365                                               shaderGroupHandleSize, shaderGroupHandleSize);
366         missShaderBindingTableRegion =
367             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
368                                               shaderGroupHandleSize, shaderGroupHandleSize);
369         callableShaderBindingTableRegion =
370             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0),
371                                               shaderGroupHandleSize, 2 * shaderGroupHandleSize);
372         break;
373     }
374     case CSTT_HIT_CALL:
375     {
376         raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
377             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
378         hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
379             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
380         missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
381             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
382         callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
383             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
384 
385         raygenShaderBindingTableRegion =
386             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
387                                               shaderGroupHandleSize, shaderGroupHandleSize);
388         hitShaderBindingTableRegion =
389             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
390                                               shaderGroupHandleSize, shaderGroupHandleSize);
391         missShaderBindingTableRegion =
392             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
393                                               shaderGroupHandleSize, shaderGroupHandleSize);
394         callableShaderBindingTableRegion =
395             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0),
396                                               shaderGroupHandleSize, shaderGroupHandleSize);
397         break;
398     }
399     case CSTT_RGEN_MULTICALL:
400     {
401         raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
402             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
403         hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
404             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
405         missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
406             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
407         callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
408             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 4);
409 
410         raygenShaderBindingTableRegion =
411             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
412                                               shaderGroupHandleSize, shaderGroupHandleSize);
413         hitShaderBindingTableRegion =
414             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
415                                               shaderGroupHandleSize, shaderGroupHandleSize);
416         missShaderBindingTableRegion =
417             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0),
418                                               shaderGroupHandleSize, shaderGroupHandleSize);
419         callableShaderBindingTableRegion =
420             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0),
421                                               shaderGroupHandleSize, 4 * shaderGroupHandleSize);
422         break;
423     }
424     default:
425         TCU_THROW(InternalError, "Wrong shader test type");
426     }
427 }
428 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)429 bool SingleSquareConfiguration::verifyImage(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams)
430 {
431     // create result image
432     tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
433     tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 1,
434                                              resultBuffer->getAllocation().getHostPtr());
435 
436     // create reference image
437     std::vector<uint32_t> reference(testParams.width * testParams.height);
438     tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 1, reference.data());
439 
440     tcu::UVec4 missValue, hitValue;
441 
442     // clear reference image with hit and miss values ( hit works only for tests calling traceRayEXT in rgen shader )
443     switch (testParams.callableShaderTestType)
444     {
445     case CSTT_RGEN_CALL:
446         missValue = tcu::UVec4(1, 0, 0, 0);
447         hitValue  = tcu::UVec4(1, 0, 0, 0);
448         break;
449     case CSTT_RGEN_CALL_CALL:
450         missValue = tcu::UVec4(1, 0, 0, 0);
451         hitValue  = tcu::UVec4(1, 0, 0, 0);
452         break;
453     case CSTT_HIT_CALL:
454         missValue = tcu::UVec4(1, 0, 0, 0);
455         hitValue  = tcu::UVec4(2, 0, 0, 0);
456         break;
457     case CSTT_RGEN_MULTICALL:
458         missValue = tcu::UVec4(16, 0, 0, 0);
459         hitValue  = tcu::UVec4(16, 0, 0, 0);
460         break;
461     default:
462         TCU_THROW(InternalError, "Wrong shader test type");
463     }
464 
465     tcu::clear(referenceAccess, missValue);
466     for (uint32_t y = 1; y < testParams.width - 1; ++y)
467         for (uint32_t x = 1; x < testParams.height - 1; ++x)
468             referenceAccess.setPixel(hitValue, x, y);
469 
470     // compare result and reference
471     return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess,
472                                     resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
473 }
474 
getResultImageFormat()475 VkFormat SingleSquareConfiguration::getResultImageFormat()
476 {
477     return VK_FORMAT_R32_UINT;
478 }
479 
getResultImageFormatSize()480 size_t SingleSquareConfiguration::getResultImageFormatSize()
481 {
482     return sizeof(uint32_t);
483 }
484 
getClearValue()485 VkClearValue SingleSquareConfiguration::getClearValue()
486 {
487     return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
488 }
489 
490 class CallableShaderTestCase : public TestCase
491 {
492 public:
493     CallableShaderTestCase(tcu::TestContext &context, const char *name, const TestParams data);
494     ~CallableShaderTestCase(void);
495 
496     virtual void checkSupport(Context &context) const;
497     virtual void initPrograms(SourceCollections &programCollection) const;
498     virtual TestInstance *createInstance(Context &context) const;
499 
500 private:
501     TestParams m_data;
502 };
503 
504 class CallableShaderTestInstance : public TestInstance
505 {
506 public:
507     CallableShaderTestInstance(Context &context, const TestParams &data);
508     ~CallableShaderTestInstance(void);
509     tcu::TestStatus iterate(void);
510 
511 protected:
512     de::MovePtr<BufferWithMemory> runTest();
513 
514 private:
515     TestParams m_data;
516 };
517 
CallableShaderTestCase(tcu::TestContext & context,const char * name,const TestParams data)518 CallableShaderTestCase::CallableShaderTestCase(tcu::TestContext &context, const char *name, const TestParams data)
519     : vkt::TestCase(context, name)
520     , m_data(data)
521 {
522 }
523 
~CallableShaderTestCase(void)524 CallableShaderTestCase::~CallableShaderTestCase(void)
525 {
526 }
527 
checkSupport(Context & context) const528 void CallableShaderTestCase::checkSupport(Context &context) const
529 {
530     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
531     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
532 
533     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
534         context.getRayTracingPipelineFeatures();
535     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
536         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
537 
538     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
539         context.getAccelerationStructureFeatures();
540     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
541         TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires "
542                              "VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
543 }
544 
initPrograms(SourceCollections & programCollection) const545 void CallableShaderTestCase::initPrograms(SourceCollections &programCollection) const
546 {
547     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
548     {
549         std::stringstream css;
550         css << "#version 460 core\n"
551                "#extension GL_EXT_ray_tracing : require\n"
552                "layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
553                "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
554                "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
555                "\n"
556                "void main()\n"
557                "{\n"
558                "  float tmin     = 0.0;\n"
559                "  float tmax     = 1.0;\n"
560                "  vec3  origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5f, float(gl_LaunchIDEXT.y) + 0.5f, 0.5f);\n"
561                "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
562                "  hitValue       = uvec4(0,0,0,0);\n"
563                "  traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
564                "  imageStore(result, ivec2(gl_LaunchIDEXT.xy), hitValue);\n"
565                "}\n";
566         programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
567     }
568 
569     {
570         std::stringstream css;
571         css << "#version 460 core\n"
572                "#extension GL_EXT_ray_tracing : require\n"
573                "layout(location = 0) callableDataEXT uvec4 value;\n"
574                "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
575                "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
576                "\n"
577                "void main()\n"
578                "{\n"
579                "  executeCallableEXT(0, 0);\n"
580                "  imageStore(result, ivec2(gl_LaunchIDEXT.xy), value);\n"
581                "}\n";
582         programCollection.glslSources.add("rgen_call")
583             << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
584     }
585 
586     {
587         std::stringstream css;
588         css << "#version 460 core\n"
589                "#extension GL_EXT_ray_tracing : require\n"
590                "struct CallValue\n"
591                "{\n"
592                "  ivec4 a;\n"
593                "  vec4  b;\n"
594                "};\n"
595                "layout(location = 0) callableDataEXT uvec4 value0;\n"
596                "layout(location = 1) callableDataEXT uint value1;\n"
597                "layout(location = 2) callableDataEXT CallValue value2;\n"
598                "layout(location = 4) callableDataEXT vec3 value3;\n"
599                "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
600                "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
601                "\n"
602                "void main()\n"
603                "{\n"
604                "  executeCallableEXT(0, 0);\n"
605                "  executeCallableEXT(1, 1);\n"
606                "  executeCallableEXT(2, 2);\n"
607                "  executeCallableEXT(3, 4);\n"
608                "  uint resultValue = value0.x + value1 + value2.a.x * uint(floor(value2.b.y)) + "
609                "uint(floor(value3.z));\n"
610                "  imageStore(result, ivec2(gl_LaunchIDEXT.xy), uvec4(resultValue, 0, 0, 0));\n"
611                "}\n";
612         programCollection.glslSources.add("rgen_multicall")
613             << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
614     }
615 
616     {
617         std::stringstream css;
618         css << "#version 460 core\n"
619                "#extension GL_EXT_ray_tracing : require\n"
620                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
621                "void main()\n"
622                "{\n"
623                "  hitValue = uvec4(1,0,0,1);\n"
624                "}\n";
625 
626         programCollection.glslSources.add("chit")
627             << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
628     }
629 
630     {
631         std::stringstream css;
632         css << "#version 460 core\n"
633                "#extension GL_EXT_ray_tracing : require\n"
634                "layout(location = 0) callableDataEXT uvec4 value;\n"
635                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
636                "void main()\n"
637                "{\n"
638                "  executeCallableEXT(0, 0);\n"
639                "  hitValue = value;\n"
640                "  hitValue.x = hitValue.x + 1;\n"
641                "}\n";
642 
643         programCollection.glslSources.add("chit_call")
644             << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
645     }
646 
647     {
648         std::stringstream css;
649         css << "#version 460 core\n"
650                "#extension GL_EXT_ray_tracing : require\n"
651                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
652                "void main()\n"
653                "{\n"
654                "  hitValue = uvec4(0,0,0,1);\n"
655                "}\n";
656 
657         programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
658     }
659 
660     {
661         std::stringstream css;
662         css << "#version 460 core\n"
663                "#extension GL_EXT_ray_tracing : require\n"
664                "layout(location = 0) callableDataEXT uvec4 value;\n"
665                "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
666                "void main()\n"
667                "{\n"
668                "  executeCallableEXT(0, 0);\n"
669                "  hitValue = value;\n"
670                "}\n";
671 
672         programCollection.glslSources.add("miss_call")
673             << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
674     }
675 
676     std::vector<std::string> callableDataDefinition = {
677         "layout(location = 0) callableDataInEXT uvec4 result;\n",
678         "layout(location = 1) callableDataInEXT uint result;\n",
679         "struct CallValue\n{\n  ivec4 a;\n  vec4  b;\n};\nlayout(location = 2) callableDataInEXT CallValue result;\n",
680         "layout(location = 4) callableDataInEXT vec3 result;\n"};
681 
682     std::vector<std::string> callableDataComputation = {
683         "  result = uvec4(1,0,0,1);\n",
684         "  result = 2;\n",
685         "  result.a = ivec4(3,0,0,1);\n  result.b = vec4(1.0, 3.2, 0.0, 1);\n",
686         "  result = vec3(0.0, 0.0, 4.3);\n",
687     };
688 
689     for (uint32_t idx = 0; idx < callableDataDefinition.size(); ++idx)
690     {
691         std::stringstream css;
692         css << "#version 460 core\n"
693                "#extension GL_EXT_ray_tracing : require\n"
694             << callableDataDefinition[idx]
695             << "void main()\n"
696                "{\n"
697             << callableDataComputation[idx] << "}\n";
698         std::stringstream csname;
699         csname << "call_" << idx;
700 
701         programCollection.glslSources.add(csname.str())
702             << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
703     }
704 
705     {
706         std::stringstream css;
707         css << "#version 460 core\n"
708                "#extension GL_EXT_ray_tracing : require\n"
709                "layout(location = 0) callableDataInEXT uvec4 result;\n"
710                "layout(location = 1) callableDataEXT uvec4 info;\n"
711                "void main()\n"
712                "{\n"
713                "  executeCallableEXT(1, 1);\n"
714                "  result = info;\n"
715                "}\n";
716 
717         programCollection.glslSources.add("call_call")
718             << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
719     }
720 }
721 
createInstance(Context & context) const722 TestInstance *CallableShaderTestCase::createInstance(Context &context) const
723 {
724     return new CallableShaderTestInstance(context, m_data);
725 }
726 
CallableShaderTestInstance(Context & context,const TestParams & data)727 CallableShaderTestInstance::CallableShaderTestInstance(Context &context, const TestParams &data)
728     : vkt::TestInstance(context)
729     , m_data(data)
730 {
731 }
732 
~CallableShaderTestInstance(void)733 CallableShaderTestInstance::~CallableShaderTestInstance(void)
734 {
735 }
736 
runTest()737 de::MovePtr<BufferWithMemory> CallableShaderTestInstance::runTest()
738 {
739     const InstanceInterface &vki          = m_context.getInstanceInterface();
740     const DeviceInterface &vkd            = m_context.getDeviceInterface();
741     const VkDevice device                 = m_context.getDevice();
742     const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
743     const uint32_t queueFamilyIndex       = m_context.getUniversalQueueFamilyIndex();
744     const VkQueue queue                   = m_context.getUniversalQueue();
745     Allocator &allocator                  = m_context.getDefaultAllocator();
746     const uint32_t pixelCount             = m_data.width * m_data.height * 1;
747 
748     const Move<VkDescriptorSetLayout> descriptorSetLayout =
749         DescriptorSetLayoutBuilder()
750             .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
751             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
752             .build(vkd, device);
753     const Move<VkDescriptorPool> descriptorPool =
754         DescriptorPoolBuilder()
755             .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
756             .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
757             .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
758     const Move<VkDescriptorSet> descriptorSet   = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
759     const Move<VkPipelineLayout> pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
760 
761     de::MovePtr<RayTracingPipeline> rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
762     m_data.testConfiguration->initRayTracingShaders(rayTracingPipeline, m_context, m_data);
763     Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);
764 
765     de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
766     de::MovePtr<BufferWithMemory> hitShaderBindingTable;
767     de::MovePtr<BufferWithMemory> missShaderBindingTable;
768     de::MovePtr<BufferWithMemory> callableShaderBindingTable;
769     VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion;
770     VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion;
771     VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion;
772     VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion;
773     m_data.testConfiguration->initShaderBindingTables(
774         rayTracingPipeline, m_context, m_data, *pipeline, getShaderGroupHandleSize(vki, physicalDevice),
775         getShaderGroupBaseAlignment(vki, physicalDevice), raygenShaderBindingTable, hitShaderBindingTable,
776         missShaderBindingTable, callableShaderBindingTable, raygenShaderBindingTableRegion, hitShaderBindingTableRegion,
777         missShaderBindingTableRegion, callableShaderBindingTableRegion);
778 
779     const VkFormat imageFormat              = m_data.testConfiguration->getResultImageFormat();
780     const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_data.width, m_data.height, imageFormat);
781     const VkImageSubresourceRange imageSubresourceRange =
782         makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
783     const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(
784         new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
785     const Move<VkImageView> imageView =
786         makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_2D, imageFormat, imageSubresourceRange);
787 
788     const VkBufferCreateInfo resultBufferCreateInfo = makeBufferCreateInfo(
789         pixelCount * m_data.testConfiguration->getResultImageFormatSize(), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
790     const VkImageSubresourceLayers resultBufferImageSubresourceLayers =
791         makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
792     const VkBufferImageCopy resultBufferImageRegion =
793         makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 1), resultBufferImageSubresourceLayers);
794     de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(
795         new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
796 
797     const VkDescriptorImageInfo descriptorImageInfo =
798         makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
799 
800     const Move<VkCommandPool> cmdPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
801     const Move<VkCommandBuffer> cmdBuffer =
802         allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
803 
804     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelAccelerationStructures;
805     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
806 
807     beginCommandBuffer(vkd, *cmdBuffer, 0u);
808     {
809         const VkImageMemoryBarrier preImageBarrier =
810             makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
811                                    VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, **image, imageSubresourceRange);
812         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
813                                       VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
814 
815         const VkClearValue clearValue = m_data.testConfiguration->getClearValue();
816         vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1,
817                                &imageSubresourceRange);
818 
819         const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(
820             VK_ACCESS_TRANSFER_WRITE_BIT,
821             VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
822             VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, **image, imageSubresourceRange);
823         cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
824                                       VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
825 
826         bottomLevelAccelerationStructures =
827             m_data.testConfiguration->initBottomAccelerationStructures(m_context, m_data);
828         for (auto &blas : bottomLevelAccelerationStructures)
829             blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
830         topLevelAccelerationStructure = m_data.testConfiguration->initTopAccelerationStructure(
831             m_context, m_data, bottomLevelAccelerationStructures);
832         topLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);
833 
834         const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = topLevelAccelerationStructure.get();
835         VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
836             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
837             DE_NULL,                                                           //  const void* pNext;
838             1u,                                                                //  uint32_t accelerationStructureCount;
839             topLevelAccelerationStructurePtr->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
840         };
841 
842         DescriptorSetUpdateBuilder()
843             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u),
844                          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
845             .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u),
846                          VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
847             .update(vkd, device);
848 
849         vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1,
850                                   &descriptorSet.get(), 0, DE_NULL);
851 
852         vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
853 
854         cmdTraceRays(vkd, *cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
855                      &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, m_data.width, m_data.height, 1);
856 
857         const VkMemoryBarrier postTraceMemoryBarrier =
858             makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
859         const VkMemoryBarrier postCopyMemoryBarrier =
860             makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
861         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR,
862                                  VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
863 
864         vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u,
865                                  &resultBufferImageRegion);
866 
867         cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
868                                  &postCopyMemoryBarrier);
869     }
870     endCommandBuffer(vkd, *cmdBuffer);
871 
872     submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
873 
874     invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(),
875                                 resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
876 
877     return resultBuffer;
878 }
879 
iterate(void)880 tcu::TestStatus CallableShaderTestInstance::iterate(void)
881 {
882     // run test using arrays of pointers
883     const de::MovePtr<BufferWithMemory> buffer = runTest();
884 
885     if (!m_data.testConfiguration->verifyImage(buffer.get(), m_context, m_data))
886         return tcu::TestStatus::fail("Fail");
887     return tcu::TestStatus::pass("Pass");
888 }
889 
// Callable data "location" values emitted into the generated GLSL:
// uint data, float data, and the outgoing uint data used by the callable shader itself.
constexpr uint32_t callableDataUintLoc    = 0u;
constexpr uint32_t callableDataFloatLoc   = 1u;
constexpr uint32_t callableDataUintOutLoc = 2u;
893 
// Plain aggregate of four 32-bit unsigned values.
// NOTE(review): presumably mirrors a buffer read by a callable shader - confirm against its users.
struct CallableBuffer0
{
    uint32_t base;
    uint32_t shift;
    uint32_t offset;
    uint32_t multiplier;
};
901 
// Plain aggregate of two floats and one 32-bit unsigned value.
// The spelling of "denomenator" is kept because the member name is part of the interface.
// NOTE(review): presumably mirrors a buffer read by a callable shader - confirm against its users.
struct CallableBuffer1
{
    float numerator;
    float denomenator;
    uint32_t power;
};
908 
909 struct Ray
910 {
Rayvkt::RayTracing::__anon6b138d0d0111::Ray911     Ray() : o(0.0f), tmin(0.0f), d(0.0f), tmax(0.0f)
912     {
913     }
Rayvkt::RayTracing::__anon6b138d0d0111::Ray914     Ray(const tcu::Vec3 &io, float imin, const tcu::Vec3 &id, float imax) : o(io), tmin(imin), d(id), tmax(imax)
915     {
916     }
917     tcu::Vec3 o;
918     float tmin;
919     tcu::Vec3 d;
920     float tmax;
921 };
922 
// Value the generated ray generation shader stores in payload.closestT before tracing;
// verifyResultData() expects it as closestT for rays that did not hit.
constexpr float MAX_T_VALUE = 1000.0f;
924 
AddVertexLayers(std::vector<tcu::Vec3> * pVerts,uint32_t newLayers)925 void AddVertexLayers(std::vector<tcu::Vec3> *pVerts, uint32_t newLayers)
926 {
927     size_t vertsPerLayer = pVerts->size();
928     size_t totalLayers   = 1 + newLayers;
929 
930     pVerts->reserve(pVerts->size() * totalLayers);
931 
932     for (size_t layer = 0; layer < newLayers; ++layer)
933     {
934         for (size_t vert = 0; vert < vertsPerLayer; ++vert)
935         {
936             bool flippedLayer = (layer % 2) == 0;
937             tcu::Vec3 stage   = (*pVerts)[flippedLayer ? (vertsPerLayer - vert - 1) : vert];
938             ++stage.z();
939 
940             pVerts->push_back(stage);
941         }
942     }
943 }
944 
// Returns true when actual is within an absolute tolerance of 0.01 of expected.
//
// The magnitude of the difference is computed directly on floats. An unqualified
// abs() call may resolve to the C library's int abs(int) depending on which
// headers are visible, which would truncate the difference to an integer and let
// errors smaller than 1.0 pass.
bool compareFloat(float actual, float expected)
{
    constexpr float eps = 0.01f;

    const float delta     = expected - actual;
    const float magnitude = (delta < 0.0f) ? -delta : delta;

    // Written as a negated "greater than" so that a NaN difference is not reported
    // as a mismatch, matching the original comparison.
    return !(magnitude > eps);
}
957 
958 class InvokeCallableShaderTestCase : public TestCase
959 {
960 public:
961     InvokeCallableShaderTestCase(tcu::TestContext &context, const char *name, const TestParams &data);
962     ~InvokeCallableShaderTestCase(void);
963 
964     virtual void checkSupport(Context &context) const;
965     virtual void initPrograms(SourceCollections &programCollection) const;
966     virtual TestInstance *createInstance(Context &context) const;
967 
968 private:
969     TestParams params;
970 };
971 
972 class InvokeCallableShaderTestInstance : public TestInstance
973 {
974 public:
975     InvokeCallableShaderTestInstance(Context &context, const TestParams &data);
976     ~InvokeCallableShaderTestInstance(void);
977     tcu::TestStatus iterate(void);
978 
979 private:
980     TestParams params;
981 };
982 
InvokeCallableShaderTestCase(tcu::TestContext & context,const char * name,const TestParams & data)983 InvokeCallableShaderTestCase::InvokeCallableShaderTestCase(tcu::TestContext &context, const char *name,
984                                                            const TestParams &data)
985     : vkt::TestCase(context, name)
986     , params(data)
987 {
988 }
989 
~InvokeCallableShaderTestCase(void)990 InvokeCallableShaderTestCase::~InvokeCallableShaderTestCase(void)
991 {
992 }
993 
checkSupport(Context & context) const994 void InvokeCallableShaderTestCase::checkSupport(Context &context) const
995 {
996     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
997     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
998 
999     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
1000         context.getRayTracingPipelineFeatures();
1001     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
1002         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1003 
1004     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
1005         context.getAccelerationStructureFeatures();
1006     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
1007         TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires "
1008                              "VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1009 }
1010 
1011 //resultData:
1012 // x - value0
1013 // y - value1
1014 // z - value2
1015 // w - closestT
verifyResultData(const tcu::Vec4 * resultData,uint32_t index,bool hit,const TestParams & params)1016 bool verifyResultData(const tcu::Vec4 *resultData, uint32_t index, bool hit, const TestParams &params)
1017 {
1018     bool success = true;
1019 
1020     float refValue0 = 0.0f;
1021     float refValue1 = 0.0f;
1022     float refValue2 = 0.0f;
1023 
1024     if (hit)
1025     {
1026         switch (params.invokingShader)
1027         {
1028         case glu::SHADERTYPE_RAYGEN:
1029         case glu::SHADERTYPE_CLOSEST_HIT:
1030         case glu::SHADERTYPE_CALLABLE:
1031             refValue0 = 133.0f;
1032             break;
1033         case glu::SHADERTYPE_MISS:
1034             break;
1035         default:
1036             TCU_THROW(InternalError, "Wrong shader invoking type");
1037             break;
1038         }
1039 
1040         if (params.multipleInvocations)
1041         {
1042             switch (params.invokingShader)
1043             {
1044             case glu::SHADERTYPE_RAYGEN:
1045                 refValue1 = 17.64f;
1046                 refValue2 = 35.28f;
1047                 break;
1048             case glu::SHADERTYPE_CLOSEST_HIT:
1049                 refValue1 = 17.64f;
1050                 refValue2 = index < 4 ? 35.28f : 8.82f;
1051                 break;
1052             case glu::SHADERTYPE_CALLABLE:
1053                 refValue1 = 17.64f;
1054                 refValue2 = index < 6 ? 35.28f : 8.82f;
1055                 break;
1056             case glu::SHADERTYPE_MISS:
1057                 break;
1058             default:
1059                 TCU_THROW(InternalError, "Wrong shader invoking type");
1060                 break;
1061             }
1062         }
1063 
1064         if (!compareFloat(resultData->w(), 2.0f))
1065         {
1066             success = false;
1067         }
1068     }
1069 
1070     if (!hit)
1071     {
1072         switch (params.invokingShader)
1073         {
1074         case glu::SHADERTYPE_RAYGEN:
1075         case glu::SHADERTYPE_MISS:
1076         case glu::SHADERTYPE_CALLABLE:
1077             refValue0 = 133.0f;
1078             break;
1079         case glu::SHADERTYPE_CLOSEST_HIT:
1080             break;
1081         default:
1082             TCU_THROW(InternalError, "Wrong shader invoking type");
1083             break;
1084         }
1085 
1086         if (params.multipleInvocations)
1087         {
1088             switch (params.invokingShader)
1089             {
1090             case glu::SHADERTYPE_RAYGEN:
1091                 refValue1 = 17.64f;
1092                 refValue2 = 8.82f;
1093                 break;
1094             case glu::SHADERTYPE_MISS:
1095                 refValue1 = 17.64f;
1096                 refValue2 = index < 10 ? 35.28f : 8.82f;
1097                 break;
1098             case glu::SHADERTYPE_CALLABLE:
1099                 refValue1 = 17.64f;
1100                 refValue2 = index < 6 ? 35.28f : 8.82f;
1101                 break;
1102             case glu::SHADERTYPE_CLOSEST_HIT:
1103                 break;
1104             default:
1105                 TCU_THROW(InternalError, "Wrong shader invoking type");
1106                 break;
1107             }
1108         }
1109 
1110         if (!compareFloat(resultData->w(), MAX_T_VALUE))
1111         {
1112             success = false;
1113         }
1114     }
1115 
1116     if ((!compareFloat(resultData->x(), refValue0)) || (!compareFloat(resultData->y(), refValue1)) ||
1117         (!compareFloat(resultData->z(), refValue2)))
1118     {
1119         success = false;
1120     }
1121 
1122     return success;
1123 }
1124 
getRayGenSource(bool invokeCallable,bool multiInvoke)1125 std::string getRayGenSource(bool invokeCallable, bool multiInvoke)
1126 {
1127     std::ostringstream src;
1128     src << "struct Payload { uint lastShader; float closestT; };\n"
1129            "layout(location = 0) rayPayloadEXT Payload payload;\n";
1130 
1131     if (invokeCallable)
1132     {
1133         src << "#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc
1134             << "\n"
1135                "layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1136 
1137         if (multiInvoke)
1138         {
1139             src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1140                 << "\n"
1141                    "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1142         }
1143     }
1144 
1145     src << "void main() {\n"
1146            "   uint index = launchIndex();\n"
1147            "   Ray ray = rays[index];\n"
1148            "   results[index].value0 = 0;\n"
1149            "   results[index].value1 = 0;\n"
1150            "   results[index].value2 = 0;\n";
1151 
1152     if (invokeCallable)
1153     {
1154         src << "   callableDataUint = "
1155             << "0"
1156             << ";\n"
1157                "   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1158                "   results[index].value0 = float(callableDataUint);\n";
1159 
1160         if (multiInvoke)
1161         {
1162             src << "   callableDataFloat = 0.0;\n"
1163                    "   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1164                    "   results[index].value1 = callableDataFloat;\n";
1165         }
1166     }
1167 
1168     src << "   payload.lastShader = " << glu::SHADERTYPE_RAYGEN
1169         << ";\n"
1170            "   payload.closestT = "
1171         << MAX_T_VALUE
1172         << ";\n"
1173            "   traceRayEXT(scene, 0x0, 0xff, 0, 0, 0, ray.pos, "
1174         << "ray.tmin"
1175         << ", ray.dir, ray.tmax, 0);\n";
1176 
1177     if (invokeCallable && multiInvoke)
1178     {
1179         src << "   executeCallableEXT(payload.lastShader == " << glu::SHADERTYPE_CLOSEST_HIT
1180             << " ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1181                "   results[index].value2 = callableDataFloat;\n";
1182     }
1183 
1184     src << "   results[index].closestT = payload.closestT;\n"
1185            "}";
1186 
1187     return src.str();
1188 }
1189 
getClosestHitSource(bool invokeCallable,bool multiInvoke)1190 std::string getClosestHitSource(bool invokeCallable, bool multiInvoke)
1191 {
1192     std::ostringstream src;
1193     src << "struct Payload { uint lastShader; float closestT; };\n"
1194            "layout(location = 0) rayPayloadInEXT Payload payload;\n";
1195 
1196     if (invokeCallable)
1197     {
1198         src << "#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc
1199             << "\n"
1200                "layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1201 
1202         if (multiInvoke)
1203         {
1204             src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1205                 << "\n"
1206                    "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1207         }
1208     }
1209 
1210     src << "void main() {\n"
1211            "   payload.lastShader = "
1212         << glu::SHADERTYPE_CLOSEST_HIT
1213         << ";\n"
1214            "   payload.closestT = gl_HitTEXT;\n";
1215 
1216     if (invokeCallable)
1217     {
1218         src << "   uint index = launchIndex();\n"
1219                "   callableDataUint = 0;\n"
1220                "   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1221                "   results[index].value0 = float(callableDataUint);\n";
1222 
1223         if (multiInvoke)
1224         {
1225             src << "   callableDataFloat = 0.0;\n"
1226                    "   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1227                    "   results[index].value1 = callableDataFloat;\n"
1228                    "   executeCallableEXT(index < 4 ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1229                    "   results[index].value2 = callableDataFloat;\n";
1230         }
1231     }
1232 
1233     src << "}";
1234 
1235     return src.str();
1236 }
1237 
getMissSource(bool invokeCallable,bool multiInvoke)1238 std::string getMissSource(bool invokeCallable, bool multiInvoke)
1239 {
1240     std::ostringstream src;
1241     src << "struct Payload { uint lastShader; float closestT; };\n"
1242            "layout(location = 0) rayPayloadInEXT Payload payload;\n";
1243 
1244     if (invokeCallable)
1245     {
1246         src << "#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc
1247             << "\n"
1248                "layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1249 
1250         if (multiInvoke)
1251         {
1252             src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1253                 << "\n"
1254                    "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1255         }
1256     }
1257 
1258     src << "void main() {\n"
1259            "   payload.lastShader = "
1260         << glu::SHADERTYPE_MISS << ";\n";
1261 
1262     if (invokeCallable)
1263     {
1264         src << "   uint index = launchIndex();\n"
1265                "   callableDataUint = 0;\n"
1266                "   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1267                "   results[index].value0 = float(callableDataUint);\n";
1268 
1269         if (multiInvoke)
1270         {
1271             src << "   callableDataFloat = 0.0;\n"
1272                    "   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1273                    "   results[index].value1 = callableDataFloat;\n"
1274                    "   executeCallableEXT(index < 10 ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1275                    "   results[index].value2 = callableDataFloat;\n";
1276         }
1277     }
1278 
1279     src << "}";
1280 
1281     return src.str();
1282 }
1283 
getCallableSource(bool invokeCallable,bool multiInvoke)1284 std::string getCallableSource(bool invokeCallable, bool multiInvoke)
1285 {
1286     std::ostringstream src;
1287     src << "#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc
1288         << "\n"
1289            "layout(location = CALLABLE_DATA_UINT_LOC) callableDataInEXT uint callableDataUintIn;\n";
1290 
1291     if (invokeCallable)
1292     {
1293         src << "#define CALLABLE_DATA_UINT_OUT_LOC " << callableDataUintOutLoc << "\n"
1294             << "layout(location = CALLABLE_DATA_UINT_OUT_LOC) callableDataEXT uint callableDataUint;\n";
1295 
1296         if (multiInvoke)
1297         {
1298             src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1299                 << "\n"
1300                    "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1301         }
1302     }
1303 
1304     src << "void main() {\n";
1305 
1306     if (invokeCallable)
1307     {
1308         src << "   uint index = launchIndex();\n"
1309                "   callableDataUint = 0;\n"
1310                "   executeCallableEXT(1, CALLABLE_DATA_UINT_OUT_LOC);\n"
1311                "   callableDataUintIn = callableDataUint;\n";
1312 
1313         if (multiInvoke)
1314         {
1315             src << "   callableDataFloat = 0.0;\n"
1316                    "   executeCallableEXT(2, CALLABLE_DATA_FLOAT_LOC);\n"
1317                    "   results[index].value1 = callableDataFloat;\n"
1318                    "   executeCallableEXT(index < 6 ? 2 : 3, CALLABLE_DATA_FLOAT_LOC);\n"
1319                    "   results[index].value2 = callableDataFloat;\n";
1320         }
1321     }
1322 
1323     src << "}";
1324 
1325     return src.str();
1326 }
1327 
// Descriptor binding indices (set 0) used by the generated shader declarations.
constexpr uint32_t DefaultResultBinding = 0u; // results storage buffer
constexpr uint32_t DefaultSceneBinding  = 1u; // acceleration structure
constexpr uint32_t DefaultRaysBinding   = 2u; // rays storage buffer
1331 
// Bit flags selecting which common declarations are emitted into a generated shader.
enum ShaderSourceFlag
{
    DEFINE_RAY           = 1 << 0, // declare the Ray struct
    DEFINE_RESULT_BUFFER = 1 << 1, // declare the results storage buffer
    DEFINE_SCENE         = 1 << 2, // declare the acceleration structure
    DEFINE_RAY_BUFFER    = 1 << 3, // declare the rays storage buffer

    // Convenience mask: all three descriptor bindings at once.
    DEFINE_SIMPLE_BINDINGS = DEFINE_RESULT_BUFFER | DEFINE_SCENE | DEFINE_RAY_BUFFER
};
1340 
generateShaderSource(const char * body,const char * resultType="",uint32_t flags=0,const char * prefix="")1341 static inline std::string generateShaderSource(const char *body, const char *resultType = "", uint32_t flags = 0,
1342                                                const char *prefix = "")
1343 {
1344     std::ostringstream src;
1345     src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n";
1346 
1347     src << "#extension GL_EXT_ray_tracing : enable\n";
1348 
1349     src << prefix << "\n";
1350 
1351     if (flags & DEFINE_SIMPLE_BINDINGS)
1352         flags |= DEFINE_RAY_BUFFER;
1353 
1354     if (flags & DEFINE_RAY_BUFFER)
1355         flags |= DEFINE_RAY;
1356 
1357     if (flags & DEFINE_RAY)
1358     {
1359         src << "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n";
1360     }
1361 
1362     if (flags & DEFINE_RESULT_BUFFER)
1363         src << "layout(std430, set = 0, binding = " << DefaultResultBinding << ") buffer Results { " << resultType
1364             << " results[]; };\n";
1365 
1366     if (flags & DEFINE_SCENE)
1367     {
1368         src << "layout(set = 0, binding = " << DefaultSceneBinding << ") uniform accelerationStructureEXT scene;\n";
1369     }
1370 
1371     if (flags & DEFINE_RAY_BUFFER)
1372         src << "layout(std430, set = 0, binding = " << DefaultRaysBinding << ") buffer Rays { Ray rays[]; };\n";
1373 
1374     src << "uint launchIndex() { return gl_LaunchIDEXT.z*gl_LaunchSizeEXT.x*gl_LaunchSizeEXT.y + "
1375            "gl_LaunchIDEXT.y*gl_LaunchSizeEXT.x + gl_LaunchIDEXT.x; }\n";
1376 
1377     src << body;
1378 
1379     return src.str();
1380 }
1381 
1382 template <typename T>
addShaderSource(SourceCollections & programCollection,const char * identifier,const char * body,const char * resultType="",uint32_t flags=0,const char * prefix="",uint32_t validatorOptions=0U)1383 inline void addShaderSource(SourceCollections &programCollection, const char *identifier, const char *body,
1384                             const char *resultType = "", uint32_t flags = 0, const char *prefix = "",
1385                             uint32_t validatorOptions = 0U)
1386 {
1387     std::string text = generateShaderSource(body, resultType, flags, prefix);
1388 
1389     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4,
1390                                               validatorOptions, true);
1391     programCollection.glslSources.add(identifier) << T(text) << buildOptions;
1392 }
1393 
initPrograms(SourceCollections & programCollection) const1394 void InvokeCallableShaderTestCase::initPrograms(SourceCollections &programCollection) const
1395 {
1396     addShaderSource<glu::RaygenSource>(programCollection, "build-raygen", getRayGenSource(false, false).c_str(),
1397                                        "Result", DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1398                                        "struct Result { float value0; float value1; float value2; float closestT;};");
1399 
1400     addShaderSource<glu::RaygenSource>(programCollection, "build-raygen-invoke-callable",
1401                                        getRayGenSource(true, false).c_str(), "Result",
1402                                        DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1403                                        "struct Result { float value0; float value1; float value2; float closestT;};");
1404 
1405     addShaderSource<glu::ClosestHitSource>(programCollection, "build-closesthit",
1406                                            getClosestHitSource(false, false).c_str(), "", 0, "");
1407 
1408     addShaderSource<glu::MissSource>(programCollection, "build-miss", getMissSource(false, false).c_str(), "", 0, "");
1409 
1410     const std::string RAY_PAYLOAD         = "rayPayloadEXT";
1411     const std::string TRACE_RAY           = "traceRayEXT";
1412     const std::string RAY_PAYLOAD_IN      = "rayPayloadInEXT";
1413     const std::string HIT_ATTRIBUTE       = "hitAttributeEXT";
1414     const std::string REPORT_INTERSECTION = "reportIntersectionEXT";
1415 
1416     const std::string SHADER_RECORD    = "shaderRecordEXT";
1417     const std::string CALLABLE_DATA_IN = "callableDataInEXT";
1418     const std::string CALLABLE_DATA    = "callableDataEXT";
1419     const std::string EXECUTE_CALLABE  = "executeCallableEXT";
1420 
1421     std::ostringstream src;
1422     src << "#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc
1423         << "\n"
1424            "layout(location = CALLABLE_DATA_UINT_LOC) callableDataInEXT uint callableDataUint;\n"
1425            "layout("
1426         << SHADER_RECORD
1427         << ") buffer callableBuffer\n"
1428            "{\n"
1429            "   uint base;\n"
1430            "   uint shift;\n"
1431            "   uint offset;\n"
1432            "   uint multiplier;\n"
1433            "};\n"
1434            "void main() {\n"
1435            "   callableDataUint += ((base << shift) + offset) * multiplier;\n"
1436            "}";
1437 
1438     addShaderSource<glu::CallableSource>(programCollection, "build-callable-0", src.str().c_str(), "", 0, "");
1439 
1440     if (params.multipleInvocations)
1441     {
1442         switch (params.invokingShader)
1443         {
1444         case glu::SHADERTYPE_RAYGEN:
1445             addShaderSource<glu::RaygenSource>(
1446                 programCollection, "build-raygen-invoke-callable-multi", getRayGenSource(true, true).c_str(), "Result",
1447                 DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1448                 "struct Result { float value0; float value1; float value2; float closestT;};");
1449 
1450             break;
1451         case glu::SHADERTYPE_CLOSEST_HIT:
1452             addShaderSource<glu::ClosestHitSource>(
1453                 programCollection, "build-closesthit-invoke-callable-multi", getClosestHitSource(true, true).c_str(),
1454                 "Result", DEFINE_RESULT_BUFFER,
1455                 "struct Result { float value0; float value1; float value2; float closestT;};");
1456 
1457             break;
1458         case glu::SHADERTYPE_MISS:
1459             addShaderSource<glu::MissSource>(
1460                 programCollection, "build-miss-invoke-callable-multi", getMissSource(true, true).c_str(), "Result",
1461                 DEFINE_RESULT_BUFFER, "struct Result { float value0; float value1; float value2; float closestT;};");
1462 
1463             break;
1464         case glu::SHADERTYPE_CALLABLE:
1465             addShaderSource<glu::CallableSource>(
1466                 programCollection, "build-callable-invoke-callable-multi", getCallableSource(true, true).c_str(),
1467                 "Result", DEFINE_RESULT_BUFFER,
1468                 "struct Result { float value0; float value1; float value2; float closestT;};");
1469 
1470             break;
1471         default:
1472             TCU_THROW(InternalError, "Wrong shader invoking type");
1473             break;
1474         }
1475 
1476         src.str(std::string());
1477         src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1478             << "\n"
1479                "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataInEXT float callableDataFloat;\n"
1480                "layout("
1481             << SHADER_RECORD
1482             << ") buffer callableBuffer\n"
1483                "{\n"
1484                "   float numerator;\n"
1485                "   float denomenator;\n"
1486                "   uint power;\n"
1487                "   uint reserved;\n"
1488                "};\n"
1489                "void main() {\n"
1490                "   float base = numerator / denomenator;\n"
1491                "   float result = 1;\n"
1492                "   for (uint i = 0; i < power; ++i)\n"
1493                "   {\n"
1494                "      result *= base;\n"
1495                "   }\n"
1496                "   callableDataFloat += result;\n"
1497                "}";
1498 
1499         addShaderSource<glu::CallableSource>(programCollection, "build-callable-1", src.str().c_str(), "", 0, "");
1500 
1501         src.str(std::string());
1502         src << "#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc
1503             << "\n"
1504                "layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataInEXT float callableDataFloat;\n"
1505                "void main() {\n"
1506                "   callableDataFloat /= 2.0f;\n"
1507                "}";
1508 
1509         addShaderSource<glu::CallableSource>(programCollection, "build-callable-2", src.str().c_str(), "", 0, "");
1510     }
1511     else
1512     {
1513         switch (params.invokingShader)
1514         {
1515         case glu::SHADERTYPE_RAYGEN:
1516             // Always defined since it's needed to invoke callable shaders that invoke other callable shaders
1517 
1518             break;
1519         case glu::SHADERTYPE_CLOSEST_HIT:
1520             addShaderSource<glu::ClosestHitSource>(
1521                 programCollection, "build-closesthit-invoke-callable", getClosestHitSource(true, false).c_str(),
1522                 "Result", DEFINE_RESULT_BUFFER,
1523                 "struct Result { float value0; float value1; float value2; float closestT;};");
1524 
1525             break;
1526         case glu::SHADERTYPE_MISS:
1527             addShaderSource<glu::MissSource>(
1528                 programCollection, "build-miss-invoke-callable", getMissSource(true, false).c_str(), "Result",
1529                 DEFINE_RESULT_BUFFER, "struct Result { float value0; float value1; float value2; float closestT;};");
1530 
1531             break;
1532         case glu::SHADERTYPE_CALLABLE:
1533             addShaderSource<glu::CallableSource>(
1534                 programCollection, "build-callable-invoke-callable", getCallableSource(true, false).c_str(), "Result",
1535                 DEFINE_RESULT_BUFFER, "struct Result { float value0; float value1; float value2; float closestT;};");
1536 
1537             break;
1538         default:
1539             TCU_THROW(InternalError, "Wrong shader invoking type");
1540             break;
1541         }
1542     }
1543 }
1544 
createInstance(Context & context) const1545 TestInstance *InvokeCallableShaderTestCase::createInstance(Context &context) const
1546 {
1547     return new InvokeCallableShaderTestInstance(context, params);
1548 }
1549 
InvokeCallableShaderTestInstance(Context & context,const TestParams & data)1550 InvokeCallableShaderTestInstance::InvokeCallableShaderTestInstance(Context &context, const TestParams &data)
1551     : vkt::TestInstance(context)
1552     , params(data)
1553 {
1554 }
1555 
~InvokeCallableShaderTestInstance(void)1556 InvokeCallableShaderTestInstance::~InvokeCallableShaderTestInstance(void)
1557 {
1558 }
1559 
iterate()1560 tcu::TestStatus InvokeCallableShaderTestInstance::iterate()
1561 {
1562     const VkDevice device        = m_context.getDevice();
1563     const DeviceInterface &vk    = m_context.getDeviceInterface();
1564     const InstanceInterface &vki = m_context.getInstanceInterface();
1565     Allocator &allocator         = m_context.getDefaultAllocator();
1566     de::MovePtr<RayTracingProperties> rayTracingProperties =
1567         makeRayTracingProperties(vki, m_context.getPhysicalDevice());
1568 
1569     vk::Move<VkDescriptorPool> descriptorPool;
1570     vk::Move<VkDescriptorSetLayout> descriptorSetLayout;
1571     std::vector<vk::Move<VkDescriptorSet>> descriptorSet;
1572     vk::Move<VkPipelineLayout> pipelineLayout;
1573 
1574     vk::DescriptorPoolBuilder descriptorPoolBuilder;
1575 
1576     uint32_t storageBufCount = 0;
1577 
1578     const VkDescriptorType accelType = VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR;
1579 
1580     storageBufCount += 1;
1581 
1582     storageBufCount += 1;
1583 
1584     descriptorPoolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, storageBufCount);
1585 
1586     descriptorPoolBuilder.addType(accelType, 1);
1587 
1588     descriptorPool = descriptorPoolBuilder.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1);
1589 
1590     vk::DescriptorSetLayoutBuilder setLayoutBuilder;
1591 
1592     const uint32_t AllStages = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
1593                                VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
1594                                VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1595 
1596     setLayoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, AllStages);
1597     setLayoutBuilder.addSingleBinding(accelType, AllStages);
1598     setLayoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, AllStages);
1599 
1600     descriptorSetLayout = setLayoutBuilder.build(vk, device);
1601 
1602     const VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {
1603         VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, // VkStructureType sType;
1604         DE_NULL,                                        // const void* pNext;
1605         *descriptorPool,                                // VkDescriptorPool descriptorPool;
1606         1u,                                             // uint32_t setLayoutCount;
1607         &descriptorSetLayout.get()                      // const VkDescriptorSetLayout* pSetLayouts;
1608     };
1609 
1610     descriptorSet.push_back(allocateDescriptorSet(vk, device, &descriptorSetAllocateInfo));
1611 
1612     const VkPipelineLayoutCreateInfo pipelineLayoutInfo = {
1613         VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // VkStructureType sType;
1614         DE_NULL,                                       // const void* pNext;
1615         (VkPipelineLayoutCreateFlags)0,                // VkPipelineLayoutCreateFlags flags;
1616         1u,                                            // uint32_t setLayoutCount;
1617         &descriptorSetLayout.get(),                    // const VkDescriptorSetLayout* pSetLayouts;
1618         0u,                                            // uint32_t pushConstantRangeCount;
1619         nullptr,                                       // const VkPushConstantRange* pPushConstantRanges;
1620     };
1621 
1622     pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutInfo);
1623 
1624     std::string raygenId     = "build-raygen";
1625     std::string missId       = "build-miss";
1626     std::string closestHitId = "build-closesthit";
1627     std::vector<std::string> callableIds;
1628 
1629     switch (params.invokingShader)
1630     {
1631     case glu::SHADERTYPE_RAYGEN:
1632     {
1633         raygenId.append("-invoke-callable");
1634 
1635         if (params.multipleInvocations)
1636         {
1637             raygenId.append("-multi");
1638         }
1639         break;
1640     }
1641     case glu::SHADERTYPE_MISS:
1642     {
1643         missId.append("-invoke-callable");
1644 
1645         if (params.multipleInvocations)
1646         {
1647             missId.append("-multi");
1648         }
1649         break;
1650     }
1651     case glu::SHADERTYPE_CLOSEST_HIT:
1652     {
1653         closestHitId.append("-invoke-callable");
1654 
1655         if (params.multipleInvocations)
1656         {
1657             closestHitId.append("-multi");
1658         }
1659         break;
1660     }
1661     case glu::SHADERTYPE_CALLABLE:
1662     {
1663         raygenId.append("-invoke-callable");
1664         std::string callableId("build-callable-invoke-callable");
1665 
1666         if (params.multipleInvocations)
1667         {
1668             callableId.append("-multi");
1669         }
1670 
1671         callableIds.push_back(callableId);
1672         break;
1673     }
1674     default:
1675     {
1676         TCU_THROW(InternalError, "Wrong shader invoking type");
1677         break;
1678     }
1679     }
1680 
1681     callableIds.push_back("build-callable-0");
1682     if (params.multipleInvocations)
1683     {
1684         callableIds.push_back("build-callable-1");
1685         callableIds.push_back("build-callable-2");
1686     }
1687 
1688     de::MovePtr<RayTracingPipeline> rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
1689     rayTracingPipeline->addShader(
1690         VK_SHADER_STAGE_RAYGEN_BIT_KHR,
1691         createShaderModule(vk, device, m_context.getBinaryCollection().get(raygenId.c_str()), 0), 0);
1692     rayTracingPipeline->addShader(
1693         VK_SHADER_STAGE_MISS_BIT_KHR,
1694         createShaderModule(vk, device, m_context.getBinaryCollection().get(missId.c_str()), 0), 1);
1695     rayTracingPipeline->addShader(
1696         VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
1697         createShaderModule(vk, device, m_context.getBinaryCollection().get(closestHitId.c_str()), 0), 2);
1698     uint32_t callableGroup = 3;
1699     for (auto &callableId : callableIds)
1700     {
1701         rayTracingPipeline->addShader(
1702             VK_SHADER_STAGE_CALLABLE_BIT_KHR,
1703             createShaderModule(vk, device, m_context.getBinaryCollection().get(callableId.c_str()), 0), callableGroup);
1704         ++callableGroup;
1705     }
1706     Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vk, device, *pipelineLayout);
1707 
1708     CallableBuffer0 callableBuffer0 = {1, 4, 3, 7};
1709     CallableBuffer1 callableBuffer1 = {10.5, 2.5, 2};
1710 
1711     size_t MaxBufferSize              = std::max(sizeof(callableBuffer0), sizeof(callableBuffer1));
1712     uint32_t shaderGroupHandleSize    = rayTracingProperties->getShaderGroupHandleSize();
1713     uint32_t shaderGroupBaseAlignment = rayTracingProperties->getShaderGroupBaseAlignment();
1714     size_t shaderStride = deAlign32(shaderGroupHandleSize + (uint32_t)MaxBufferSize, shaderGroupHandleSize);
1715 
1716     de::MovePtr<BufferWithMemory> raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
1717         vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
1718     de::MovePtr<BufferWithMemory> missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
1719         vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
1720     de::MovePtr<BufferWithMemory> hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
1721         vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
1722     de::MovePtr<BufferWithMemory> callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
1723         vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3,
1724         (uint32_t)callableIds.size(), 0U, 0U, MemoryRequirement::Any, 0U, 0U, (uint32_t)MaxBufferSize, nullptr, true);
1725 
1726     VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion =
1727         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, raygenShaderBindingTable->get(), 0),
1728                                           shaderGroupHandleSize, shaderGroupHandleSize);
1729     VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion =
1730         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, missShaderBindingTable->get(), 0),
1731                                           shaderGroupHandleSize, shaderGroupHandleSize);
1732     VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion =
1733         makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, hitShaderBindingTable->get(), 0),
1734                                           shaderGroupHandleSize, shaderGroupHandleSize);
1735     VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
1736         getBufferDeviceAddress(vk, device, callableShaderBindingTable->get(), 0), shaderStride, shaderGroupHandleSize);
1737 
1738     size_t callableCount = 0;
1739 
1740     if (params.invokingShader == glu::SHADERTYPE_CALLABLE)
1741     {
1742         callableCount++;
1743     }
1744 
1745     deMemcpy((uint8_t *)callableShaderBindingTable->getAllocation().getHostPtr() + (shaderStride * (callableCount)) +
1746                  shaderGroupHandleSize,
1747              &callableBuffer0, sizeof(CallableBuffer0));
1748     callableCount++;
1749 
1750     if (params.multipleInvocations)
1751     {
1752         deMemcpy((uint8_t *)callableShaderBindingTable->getAllocation().getHostPtr() +
1753                      (shaderStride * (callableCount)) + shaderGroupHandleSize,
1754                  &callableBuffer1, sizeof(CallableBuffer1));
1755         callableCount++;
1756     }
1757 
1758     flushMappedMemoryRange(vk, device, callableShaderBindingTable->getAllocation().getMemory(),
1759                            callableShaderBindingTable->getAllocation().getOffset(), VK_WHOLE_SIZE);
1760 
1761     //                 {I}
1762     // (-2,1) (-1,1)  (0,1)  (1,1)  (2,1)
1763     //    X------X------X------X------X
1764     //    |\     |\     |\     |\     |
1765     //    | \ {B}| \ {D}| \ {F}| \ {H}|
1766     // {K}|  \   |  \   |  \   |  \   |{L}
1767     //    |   \  |   \  |   \  |   \  |
1768     //    |{A} \ |{C} \ |{E} \ |{G} \ |
1769     //    |     \|     \|     \|     \|
1770     //    X------X------X------X------X
1771     // (-2,-1)(-1,-1) (0,-1) (1,-1) (2,-1)
1772     //                 {J}
1773     //
1774     // A, B, E, and F are initially opaque
1775     // A and C are forced opaque
1776     // E and G are forced non-opaque
1777 
1778     std::vector<Ray> rays = {
1779         Ray{tcu::Vec3(-1.67f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE}, // {A}
1780         Ray{tcu::Vec3(-1.33f, 0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},  // {B}
1781         Ray{tcu::Vec3(-0.67f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE}, // {C}
1782         Ray{tcu::Vec3(-0.33f, 0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},  // {D}
1783         Ray{tcu::Vec3(0.33f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},  // {E}
1784         Ray{tcu::Vec3(0.67f, 0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},   // {F}
1785         Ray{tcu::Vec3(1.33f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},  // {G}
1786         Ray{tcu::Vec3(1.67f, 0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},   // {H}
1787         Ray{tcu::Vec3(0.0f, 1.01f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},    // {I}
1788         Ray{tcu::Vec3(0.0f, -1.01f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},   // {J}
1789         Ray{tcu::Vec3(-2.01f, 0.0f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE},   // {K}
1790         Ray{tcu::Vec3(2.01f, 0.0f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE}     // {L}
1791     };
1792 
1793     // B & F
1794     std::vector<tcu::Vec3> blas0VertsOpaque = {{-2.0f, 1.0f, 2.0f}, {-1.0f, -1.0f, 2.0f}, {-1.0f, 1.0f, 2.0f},
1795                                                {0.0f, 1.0f, 2.0f},  {1.0f, -1.0f, 2.0f},  {1.0f, 1.0f, 2.0f}};
1796 
1797     // D & H
1798     std::vector<tcu::Vec3> blas0VertsNoOpaque = {{-1.0f, 1.0f, 2.0f}, {0.0f, -1.0f, 2.0f}, {0.0f, 1.0f, 2.0f},
1799                                                  {1.0f, 1.0f, 2.0f},  {2.0f, -1.0f, 2.0f}, {2.0f, 1.0f, 2.0f}};
1800 
1801     // A
1802     std::vector<tcu::Vec3> blas1VertsOpaque = {{-2.0f, 1.0f, 2.0f}, {-2.0f, -1.0f, 2.0f}, {-1.0f, -1.0f, 2.0f}};
1803 
1804     // C
1805     std::vector<tcu::Vec3> blas1VertsNoOpaque = {{-1.0f, 1.0f, 2.0f}, {-1.0f, -1.0f, 2.0f}, {0.0f, -1.0f, 2.0f}};
1806 
1807     // E
1808     std::vector<tcu::Vec3> blas2VertsOpaque = {{0.0f, 1.0f, 2.0f}, {0.0f, -1.0f, 2.0f}, {1.0f, -1.0f, 2.0f}};
1809 
1810     // G
1811     std::vector<tcu::Vec3> blas2VertsNoOpaque = {{1.0f, 1.0f, 2.0f}, {1.0f, -1.0f, 2.0f}, {2.0f, -1.0f, 2.0f}};
1812 
1813     AddVertexLayers(&blas0VertsOpaque, 1);
1814     AddVertexLayers(&blas0VertsNoOpaque, 1);
1815     AddVertexLayers(&blas1VertsOpaque, 1);
1816     AddVertexLayers(&blas1VertsNoOpaque, 1);
1817     AddVertexLayers(&blas2VertsOpaque, 1);
1818     AddVertexLayers(&blas2VertsNoOpaque, 1);
1819 
1820     std::vector<tcu::Vec3> verts;
1821     verts.reserve(blas0VertsOpaque.size() + blas0VertsNoOpaque.size() + blas1VertsOpaque.size() +
1822                   blas1VertsNoOpaque.size() + blas2VertsOpaque.size() + blas2VertsNoOpaque.size());
1823     verts.insert(verts.end(), blas0VertsOpaque.begin(), blas0VertsOpaque.end());
1824     verts.insert(verts.end(), blas0VertsNoOpaque.begin(), blas0VertsNoOpaque.end());
1825     verts.insert(verts.end(), blas1VertsOpaque.begin(), blas1VertsOpaque.end());
1826     verts.insert(verts.end(), blas1VertsNoOpaque.begin(), blas1VertsNoOpaque.end());
1827     verts.insert(verts.end(), blas2VertsOpaque.begin(), blas2VertsOpaque.end());
1828     verts.insert(verts.end(), blas2VertsNoOpaque.begin(), blas2VertsNoOpaque.end());
1829 
1830     tcu::Surface resultImage(static_cast<int>(rays.size()), 1);
1831 
1832     const VkBufferCreateInfo resultBufferCreateInfo =
1833         makeBufferCreateInfo(rays.size() * sizeof(tcu::Vec4), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1834     de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(
1835         new BufferWithMemory(vk, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
1836     const VkDescriptorBufferInfo resultDescriptorInfo = makeDescriptorBufferInfo(resultBuffer->get(), 0, VK_WHOLE_SIZE);
1837 
1838     const VkBufferCreateInfo rayBufferCreateInfo =
1839         makeBufferCreateInfo(rays.size() * sizeof(Ray), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1840     de::MovePtr<BufferWithMemory> rayBuffer = de::MovePtr<BufferWithMemory>(
1841         new BufferWithMemory(vk, device, allocator, rayBufferCreateInfo, MemoryRequirement::HostVisible));
1842     const VkDescriptorBufferInfo rayDescriptorInfo = makeDescriptorBufferInfo(rayBuffer->get(), 0, VK_WHOLE_SIZE);
1843     memcpy(rayBuffer->getAllocation().getHostPtr(), &rays[0], rays.size() * sizeof(Ray));
1844     flushMappedMemoryRange(vk, device, rayBuffer->getAllocation().getMemory(), rayBuffer->getAllocation().getOffset(),
1845                            VK_WHOLE_SIZE);
1846 
1847     const Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
1848     const Move<VkCommandBuffer> cmdBuffer =
1849         allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1850 
1851     beginCommandBuffer(vk, *cmdBuffer);
1852 
1853     de::SharedPtr<BottomLevelAccelerationStructure> blas0 =
1854         de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1855     blas0->setGeometryCount(2);
1856     blas0->addGeometry(blas0VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1857     blas0->addGeometry(blas0VertsNoOpaque, true, 0U);
1858     blas0->createAndBuild(vk, device, *cmdBuffer, allocator);
1859 
1860     de::SharedPtr<BottomLevelAccelerationStructure> blas1 =
1861         de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1862     blas1->setGeometryCount(2);
1863     blas1->addGeometry(blas1VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1864     blas1->addGeometry(blas1VertsNoOpaque, true, 0U);
1865     blas1->createAndBuild(vk, device, *cmdBuffer, allocator);
1866 
1867     de::SharedPtr<BottomLevelAccelerationStructure> blas2 =
1868         de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1869     blas2->setGeometryCount(2);
1870     blas2->addGeometry(blas2VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1871     blas2->addGeometry(blas2VertsNoOpaque, true, 0U);
1872     blas2->createAndBuild(vk, device, *cmdBuffer, allocator);
1873 
1874     de::MovePtr<TopLevelAccelerationStructure> tlas = makeTopLevelAccelerationStructure();
1875     tlas->setInstanceCount(3);
1876     tlas->addInstance(blas0);
1877     tlas->addInstance(blas1);
1878     tlas->addInstance(blas2);
1879     tlas->createAndBuild(vk, device, *cmdBuffer, allocator);
1880 
1881     VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet = {
1882         VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
1883         DE_NULL,                                                           //  const void* pNext;
1884         1u,                                                                //  uint32_t accelerationStructureCount;
1885         tlas->getPtr(), //  const VkAccelerationStructureKHR* pAccelerationStructures;
1886     };
1887 
1888     DescriptorSetUpdateBuilder()
1889         .writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultResultBinding),
1890                      VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &resultDescriptorInfo)
1891         .writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultSceneBinding),
1892                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
1893         .writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultRaysBinding),
1894                      VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &rayDescriptorInfo)
1895         .update(vk, device);
1896 
1897     vk.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
1898     vk.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1,
1899                              &descriptorSet[0].get(), 0, DE_NULL);
1900 
1901     cmdTraceRays(vk, *cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
1902                  &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, static_cast<uint32_t>(rays.size()), 1,
1903                  1);
1904 
1905     endCommandBuffer(vk, *cmdBuffer);
1906 
1907     submitCommandsAndWait(vk, device, m_context.getUniversalQueue(), *cmdBuffer);
1908 
1909     invalidateMappedMemoryRange(vk, device, resultBuffer->getAllocation().getMemory(),
1910                                 resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
1911 
1912     //                 {I}
1913     // (-2,1) (-1,1)  (0,1)  (1,1)  (2,1)
1914     //    X------X------X------X------X
1915     //    |\     |\     |\     |\     |
1916     //    | \ {B}| \ {D}| \ {F}| \ {H}|
1917     // {K}|  \   |  \   |  \   |  \   |{L}
1918     //    |   \  |   \  |   \  |   \  |
1919     //    |{A} \ |{C} \ |{E} \ |{G} \ |
1920     //    |     \|     \|     \|     \|
1921     //    X------X------X------X------X
1922     // (-2,-1)(-1,-1) (0,-1) (1,-1) (2,-1)
1923     //                 {J}
1924     // A, B, E, and F are opaque
1925     // A and C are forced opaque
1926     // E and G are forced non-opaque
1927 
1928     std::vector<bool> hits    = {true, true, true, true, true, true, true, true, false, false, false, false};
1929     std::vector<bool> opaques = {true, true, true, false, false, true, false, false, true, true, true, true};
1930 
1931     union
1932     {
1933         bool mismatch[32];
1934         uint32_t mismatchAll;
1935     };
1936     mismatchAll = 0;
1937 
1938     tcu::Vec4 *resultData = (tcu::Vec4 *)resultBuffer->getAllocation().getHostPtr();
1939 
1940     for (int index = 0; index < resultImage.getWidth(); ++index)
1941     {
1942         if (verifyResultData(&resultData[index], index, hits[index], params))
1943         {
1944             resultImage.setPixel(index, 0, tcu::RGBA(255, 0, 0, 255));
1945         }
1946         else
1947         {
1948             mismatch[index] = true;
1949             resultImage.setPixel(index, 0, tcu::RGBA(0, 0, 0, 255));
1950         }
1951     }
1952 
1953     // Write Image
1954     m_context.getTestContext().getLog() << tcu::TestLog::ImageSet("Result of rendering", "Result of rendering")
1955                                         << tcu::TestLog::Image("Result", "Result", resultImage)
1956                                         << tcu::TestLog::EndImageSet;
1957 
1958     if (mismatchAll != 0)
1959         TCU_FAIL("Result data did not match expected output");
1960 
1961     return tcu::TestStatus::pass("pass");
1962 }
1963 
1964 } // namespace
1965 
createCallableShadersTests(tcu::TestContext & testCtx)1966 tcu::TestCaseGroup *createCallableShadersTests(tcu::TestContext &testCtx)
1967 {
1968     // Tests veryfying callable shaders
1969     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "callable_shader"));
1970 
1971     struct CallableShaderTestTypeData
1972     {
1973         CallableShaderTestType shaderTestType;
1974         const char *name;
1975     } callableShaderTestTypes[] = {
1976         {CSTT_RGEN_CALL, "rgen_call"},
1977         {CSTT_RGEN_CALL_CALL, "rgen_call_call"},
1978         {CSTT_HIT_CALL, "hit_call"},
1979         {CSTT_RGEN_MULTICALL, "rgen_multicall"},
1980     };
1981 
1982     for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(callableShaderTestTypes); ++shaderTestNdx)
1983     {
1984         TestParams testParams{TEST_WIDTH,
1985                               TEST_HEIGHT,
1986                               callableShaderTestTypes[shaderTestNdx].shaderTestType,
1987                               de::SharedPtr<TestConfiguration>(new SingleSquareConfiguration()),
1988                               glu::SHADERTYPE_LAST,
1989                               false};
1990         group->addChild(new CallableShaderTestCase(group->getTestContext(), callableShaderTestTypes[shaderTestNdx].name,
1991                                                    testParams));
1992     }
1993 
1994     bool multipleInvocations[]            = {false, true};
1995     std::string multipleInvocationsText[] = {"_single_invocation", "_multiple_invocations"};
1996     // Callable shaders cannot be called from any-hit shader per GLSL_NV_ray_tracing spec. Assuming same will hold for KHR version.
1997     glu::ShaderType invokingShaders[] = {glu::SHADERTYPE_RAYGEN, glu::SHADERTYPE_CALLABLE, glu::SHADERTYPE_CLOSEST_HIT,
1998                                          glu::SHADERTYPE_MISS};
1999     std::string invokingShadersText[] = {"_invoked_via_raygen", "_invoked_via_callable", "_invoked_via_closest_hit",
2000                                          "_invoked_via_miss"};
2001 
2002     for (int j = 0; j < 4; ++j)
2003     {
2004         TestParams params;
2005 
2006         std::string name("callable_shader");
2007 
2008         params.invokingShader = invokingShaders[j];
2009         name.append(invokingShadersText[j]);
2010 
2011         for (int k = 0; k < 2; ++k)
2012         {
2013             std::string nameFull(name);
2014 
2015             params.multipleInvocations = multipleInvocations[k];
2016             nameFull.append(multipleInvocationsText[k]);
2017 
2018             group->addChild(new InvokeCallableShaderTestCase(group->getTestContext(), nameFull.c_str(), params));
2019         }
2020     }
2021 
2022     return group.release();
2023 }
2024 
2025 } // namespace RayTracing
2026 
2027 } // namespace vkt
2028