1 /*-------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  * Copyright (c) 2020 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Ray Tracing Data Spill tests
23  *//*--------------------------------------------------------------------*/
24 #include "vktRayTracingDataSpillTests.hpp"
25 #include "vktTestCase.hpp"
26 
27 #include "vkRayTracingUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBufferWithMemory.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkCmdUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35 
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38 
39 #include "deUniquePtr.hpp"
40 #include "deSTLUtil.hpp"
41 
42 #include <sstream>
43 #include <string>
44 #include <map>
45 #include <vector>
46 #include <array>
47 #include <utility>
48 
49 using namespace vk;
50 
51 namespace vkt
52 {
53 namespace RayTracing
54 {
55 
56 namespace
57 {
58 
// Kind of shader call performed between the two reads of the input value.
enum class CallType : int
{
    TRACE_RAY           = 0, // Trace a ray.
    EXECUTE_CALLABLE    = 1, // Execute a callable shader.
    REPORT_INTERSECTION = 2, // Report an intersection from an intersection shader.
};
66 
// Type of the value whose integrity is checked across the shader call.
enum class DataType : int
{
    // Numeric types: may be combined with any VectorType (scalar, vector or array).
    INT32   = 0,
    UINT32  = 1,
    INT64   = 2,
    UINT64  = 3,
    INT16   = 4,
    UINT16  = 5,
    INT8    = 6,
    UINT8   = 7,
    FLOAT32 = 8,
    FLOAT64 = 9,
    FLOAT16 = 10,

    // Standalone types: only valid with VectorType::SCALAR.
    STRUCT            = 11,
    IMAGE             = 12,
    SAMPLER           = 13,
    SAMPLED_IMAGE     = 14,
    PTR_IMAGE         = 15,
    PTR_SAMPLER       = 16,
    PTR_SAMPLED_IMAGE = 17,
    PTR_TEXEL         = 18,
    OP_NULL           = 19,
    OP_UNDEF          = 20,
};
95 
// Shape of the checked value. Each enumerator's value is its component count (or element count
// for A5), so casting it to an integer yields that count.
enum class VectorType : int
{
    SCALAR = 1, // Single value.
    V2     = 2, // Two-component vector.
    V3     = 3, // Three-component vector.
    V4     = 4, // Four-component vector.
    A5     = 5, // Array of five scalars.
};
105 
// Host-side counterpart of the GLSL "struct InputStruct { uint val1; float val2; }" used by
// the STRUCT data type.
struct InputStruct
{
    uint32_t uintPart; // Corresponds to val1 (uint).
    float floatPart;   // Corresponds to val2 (float).
};
111 
112 constexpr auto kImageFormat = VK_FORMAT_R32_UINT;
113 const auto kImageExtent     = makeExtent3D(1u, 1u, 1u);
114 
115 // For samplers.
116 const VkImageUsageFlags kSampledImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT);
117 constexpr size_t kNumImages                = 4u;
118 constexpr size_t kNumSamplers              = 4u;
119 constexpr size_t kNumCombined              = 2u;
120 constexpr size_t kNumAloneImages           = kNumImages - kNumCombined;
121 constexpr size_t kNumAloneSamplers         = kNumSamplers - kNumCombined;
122 
123 // For storage images.
124 const VkImageUsageFlags kStorageImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_STORAGE_BIT);
125 
126 // For the pipeline interface tests.
127 constexpr size_t kNumStorageValues   = 6u;
128 constexpr uint32_t kShaderRecordSize = sizeof(tcu::UVec4);
129 
130 // Get the effective vector length in memory.
getEffectiveVectorLength(VectorType vectorType)131 size_t getEffectiveVectorLength(VectorType vectorType)
132 {
133     return ((vectorType == VectorType::V3) ? static_cast<size_t>(4) : static_cast<size_t>(vectorType));
134 }
135 
136 // Get the corresponding element size.
getElementSize(DataType dataType,VectorType vectorType)137 VkDeviceSize getElementSize(DataType dataType, VectorType vectorType)
138 {
139     const size_t length = getEffectiveVectorLength(vectorType);
140     size_t dataSize     = 0u;
141 
142     switch (dataType)
143     {
144     case DataType::INT32:
145         dataSize = sizeof(int32_t);
146         break;
147     case DataType::UINT32:
148         dataSize = sizeof(uint32_t);
149         break;
150     case DataType::INT64:
151         dataSize = sizeof(int64_t);
152         break;
153     case DataType::UINT64:
154         dataSize = sizeof(uint64_t);
155         break;
156     case DataType::INT16:
157         dataSize = sizeof(int16_t);
158         break;
159     case DataType::UINT16:
160         dataSize = sizeof(uint16_t);
161         break;
162     case DataType::INT8:
163         dataSize = sizeof(int8_t);
164         break;
165     case DataType::UINT8:
166         dataSize = sizeof(uint8_t);
167         break;
168     case DataType::FLOAT32:
169         dataSize = sizeof(tcu::Float32);
170         break;
171     case DataType::FLOAT64:
172         dataSize = sizeof(tcu::Float64);
173         break;
174     case DataType::FLOAT16:
175         dataSize = sizeof(tcu::Float16);
176         break;
177     case DataType::STRUCT:
178         dataSize = sizeof(InputStruct);
179         break;
180     case DataType::IMAGE:             // fallthrough.
181     case DataType::SAMPLER:           // fallthrough.
182     case DataType::SAMPLED_IMAGE:     // fallthrough.
183     case DataType::PTR_IMAGE:         // fallthrough.
184     case DataType::PTR_SAMPLER:       // fallthrough.
185     case DataType::PTR_SAMPLED_IMAGE: // fallthrough.
186         dataSize = sizeof(tcu::Float32);
187         break;
188     case DataType::PTR_TEXEL:
189         dataSize = sizeof(int32_t);
190         break;
191     case DataType::OP_NULL:  // fallthrough.
192     case DataType::OP_UNDEF: // fallthrough.
193         dataSize = sizeof(uint32_t);
194         break;
195     default:
196         DE_ASSERT(false);
197         break;
198     }
199 
200     return static_cast<VkDeviceSize>(dataSize * length);
201 }
202 
203 // Proper stage for generating default geometry.
getShaderStageForGeometry(CallType type_)204 VkShaderStageFlagBits getShaderStageForGeometry(CallType type_)
205 {
206     VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
207 
208     switch (type_)
209     {
210     case CallType::TRACE_RAY:
211         bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
212         break;
213     case CallType::EXECUTE_CALLABLE:
214         bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
215         break;
216     case CallType::REPORT_INTERSECTION:
217         bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
218         break;
219     default:
220         DE_ASSERT(false);
221         break;
222     }
223 
224     DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
225     return bits;
226 }
227 
getShaderStages(CallType type_)228 VkShaderStageFlags getShaderStages(CallType type_)
229 {
230     VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
231 
232     switch (type_)
233     {
234     case CallType::EXECUTE_CALLABLE:
235         flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
236         break;
237     case CallType::TRACE_RAY:
238         flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
239         break;
240     case CallType::REPORT_INTERSECTION:
241         flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
242         flags |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
243         break;
244     default:
245         DE_ASSERT(false);
246         break;
247     }
248 
249     return flags;
250 }
251 
252 // Some test types need additional descriptors with samplers, images and combined image samplers.
samplersNeeded(DataType dataType)253 bool samplersNeeded(DataType dataType)
254 {
255     bool needed = false;
256 
257     switch (dataType)
258     {
259     case DataType::IMAGE:
260     case DataType::SAMPLER:
261     case DataType::SAMPLED_IMAGE:
262     case DataType::PTR_IMAGE:
263     case DataType::PTR_SAMPLER:
264     case DataType::PTR_SAMPLED_IMAGE:
265         needed = true;
266         break;
267     default:
268         break;
269     }
270 
271     return needed;
272 }
273 
274 // Some test types need an additional descriptor with a storage image.
storageImageNeeded(DataType dataType)275 bool storageImageNeeded(DataType dataType)
276 {
277     return (dataType == DataType::PTR_TEXEL);
278 }
279 
280 // Returns two strings:
281 //        .first is an optional GLSL additional type declaration (for structs, basically).
282 //        .second is the value declaration inside the input block.
getGLSLInputValDecl(DataType dataType,VectorType vectorType)283 std::pair<std::string, std::string> getGLSLInputValDecl(DataType dataType, VectorType vectorType)
284 {
285     using TypePair = std::pair<DataType, VectorType>;
286     using TypeMap  = std::map<TypePair, std::string>;
287 
288     const std::string varName = "val";
289     const auto dataTypeIdx    = static_cast<int>(dataType);
290 
291     if (dataTypeIdx >= static_cast<int>(DataType::INT32) && dataTypeIdx <= static_cast<int>(DataType::FLOAT16))
292     {
293         // Note: A5 uses the same type as the scalar version. The array suffix will be added below.
294         const TypeMap map = {
295             std::make_pair(std::make_pair(DataType::INT32, VectorType::SCALAR), "int32_t"),
296             std::make_pair(std::make_pair(DataType::INT32, VectorType::V2), "i32vec2"),
297             std::make_pair(std::make_pair(DataType::INT32, VectorType::V3), "i32vec3"),
298             std::make_pair(std::make_pair(DataType::INT32, VectorType::V4), "i32vec4"),
299             std::make_pair(std::make_pair(DataType::INT32, VectorType::A5), "int32_t"),
300             std::make_pair(std::make_pair(DataType::UINT32, VectorType::SCALAR), "uint32_t"),
301             std::make_pair(std::make_pair(DataType::UINT32, VectorType::V2), "u32vec2"),
302             std::make_pair(std::make_pair(DataType::UINT32, VectorType::V3), "u32vec3"),
303             std::make_pair(std::make_pair(DataType::UINT32, VectorType::V4), "u32vec4"),
304             std::make_pair(std::make_pair(DataType::UINT32, VectorType::A5), "uint32_t"),
305             std::make_pair(std::make_pair(DataType::INT64, VectorType::SCALAR), "int64_t"),
306             std::make_pair(std::make_pair(DataType::INT64, VectorType::V2), "i64vec2"),
307             std::make_pair(std::make_pair(DataType::INT64, VectorType::V3), "i64vec3"),
308             std::make_pair(std::make_pair(DataType::INT64, VectorType::V4), "i64vec4"),
309             std::make_pair(std::make_pair(DataType::INT64, VectorType::A5), "int64_t"),
310             std::make_pair(std::make_pair(DataType::UINT64, VectorType::SCALAR), "uint64_t"),
311             std::make_pair(std::make_pair(DataType::UINT64, VectorType::V2), "u64vec2"),
312             std::make_pair(std::make_pair(DataType::UINT64, VectorType::V3), "u64vec3"),
313             std::make_pair(std::make_pair(DataType::UINT64, VectorType::V4), "u64vec4"),
314             std::make_pair(std::make_pair(DataType::UINT64, VectorType::A5), "uint64_t"),
315             std::make_pair(std::make_pair(DataType::INT16, VectorType::SCALAR), "int16_t"),
316             std::make_pair(std::make_pair(DataType::INT16, VectorType::V2), "i16vec2"),
317             std::make_pair(std::make_pair(DataType::INT16, VectorType::V3), "i16vec3"),
318             std::make_pair(std::make_pair(DataType::INT16, VectorType::V4), "i16vec4"),
319             std::make_pair(std::make_pair(DataType::INT16, VectorType::A5), "int16_t"),
320             std::make_pair(std::make_pair(DataType::UINT16, VectorType::SCALAR), "uint16_t"),
321             std::make_pair(std::make_pair(DataType::UINT16, VectorType::V2), "u16vec2"),
322             std::make_pair(std::make_pair(DataType::UINT16, VectorType::V3), "u16vec3"),
323             std::make_pair(std::make_pair(DataType::UINT16, VectorType::V4), "u16vec4"),
324             std::make_pair(std::make_pair(DataType::UINT16, VectorType::A5), "uint16_t"),
325             std::make_pair(std::make_pair(DataType::INT8, VectorType::SCALAR), "int8_t"),
326             std::make_pair(std::make_pair(DataType::INT8, VectorType::V2), "i8vec2"),
327             std::make_pair(std::make_pair(DataType::INT8, VectorType::V3), "i8vec3"),
328             std::make_pair(std::make_pair(DataType::INT8, VectorType::V4), "i8vec4"),
329             std::make_pair(std::make_pair(DataType::INT8, VectorType::A5), "int8_t"),
330             std::make_pair(std::make_pair(DataType::UINT8, VectorType::SCALAR), "uint8_t"),
331             std::make_pair(std::make_pair(DataType::UINT8, VectorType::V2), "u8vec2"),
332             std::make_pair(std::make_pair(DataType::UINT8, VectorType::V3), "u8vec3"),
333             std::make_pair(std::make_pair(DataType::UINT8, VectorType::V4), "u8vec4"),
334             std::make_pair(std::make_pair(DataType::UINT8, VectorType::A5), "uint8_t"),
335             std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::SCALAR), "float32_t"),
336             std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V2), "f32vec2"),
337             std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V3), "f32vec3"),
338             std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V4), "f32vec4"),
339             std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::A5), "float32_t"),
340             std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::SCALAR), "float64_t"),
341             std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V2), "f64vec2"),
342             std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V3), "f64vec3"),
343             std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V4), "f64vec4"),
344             std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::A5), "float64_t"),
345             std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::SCALAR), "float16_t"),
346             std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V2), "f16vec2"),
347             std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V3), "f16vec3"),
348             std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V4), "f16vec4"),
349             std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::A5), "float16_t"),
350         };
351 
352         const auto key   = std::make_pair(dataType, vectorType);
353         const auto found = map.find(key);
354 
355         DE_ASSERT(found != end(map));
356 
357         const auto baseType    = found->second;
358         const std::string decl = baseType + " " + varName + ((vectorType == VectorType::A5) ? "[5]" : "") + ";";
359 
360         return std::make_pair(std::string(), decl);
361     }
362     else if (dataType == DataType::STRUCT)
363     {
364         return std::make_pair(std::string("struct InputStruct { uint val1; float val2; };\n"),
365                               std::string("InputStruct val;"));
366     }
367     else if (samplersNeeded(dataType))
368     {
369         return std::make_pair(std::string(), std::string("float val;"));
370     }
371     else if (storageImageNeeded(dataType))
372     {
373         return std::make_pair(std::string(), std::string("int val;"));
374     }
375     else if (dataType == DataType::OP_NULL || dataType == DataType::OP_UNDEF)
376     {
377         return std::make_pair(std::string(), std::string("uint val;"));
378     }
379 
380     // Unreachable.
381     DE_ASSERT(false);
382     return std::make_pair(std::string(), std::string());
383 }
384 
385 class DataSpillTestCase : public vkt::TestCase
386 {
387 public:
388     struct TestParams
389     {
390         CallType callType;
391         DataType dataType;
392         VectorType vectorType;
393     };
394 
395     DataSpillTestCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &testParams);
~DataSpillTestCase(void)396     virtual ~DataSpillTestCase(void)
397     {
398     }
399 
400     virtual void initPrograms(vk::SourceCollections &programCollection) const;
401     virtual TestInstance *createInstance(Context &context) const;
402     virtual void checkSupport(Context &context) const;
403 
404 private:
405     TestParams m_params;
406 };
407 
408 class DataSpillTestInstance : public vkt::TestInstance
409 {
410 public:
411     using TestParams = DataSpillTestCase::TestParams;
412 
413     DataSpillTestInstance(Context &context, const TestParams &testParams);
~DataSpillTestInstance(void)414     virtual ~DataSpillTestInstance(void)
415     {
416     }
417 
418     virtual tcu::TestStatus iterate(void);
419 
420 private:
421     TestParams m_params;
422 };
423 
DataSpillTestCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & testParams)424 DataSpillTestCase::DataSpillTestCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &testParams)
425     : vkt::TestCase(testCtx, name)
426     , m_params(testParams)
427 {
428     switch (m_params.dataType)
429     {
430     case DataType::STRUCT:
431     case DataType::IMAGE:
432     case DataType::SAMPLER:
433     case DataType::SAMPLED_IMAGE:
434     case DataType::PTR_IMAGE:
435     case DataType::PTR_SAMPLER:
436     case DataType::PTR_SAMPLED_IMAGE:
437     case DataType::PTR_TEXEL:
438     case DataType::OP_NULL:
439     case DataType::OP_UNDEF:
440         DE_ASSERT(m_params.vectorType == VectorType::SCALAR);
441         break;
442     default:
443         break;
444     }
445 
446     // The code assumes at most one of these is needed.
447     DE_ASSERT(!(samplersNeeded(m_params.dataType) && storageImageNeeded(m_params.dataType)));
448 }
449 
createInstance(Context & context) const450 TestInstance *DataSpillTestCase::createInstance(Context &context) const
451 {
452     return new DataSpillTestInstance(context, m_params);
453 }
454 
DataSpillTestInstance(Context & context,const TestParams & testParams)455 DataSpillTestInstance::DataSpillTestInstance(Context &context, const TestParams &testParams)
456     : vkt::TestInstance(context)
457     , m_params(testParams)
458 {
459 }
460 
461 // General checks for all tests.
commonCheckSupport(Context & context)462 void commonCheckSupport(Context &context)
463 {
464     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
465     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
466 
467     const auto &rtFeatures = context.getRayTracingPipelineFeatures();
468     if (!rtFeatures.rayTracingPipeline)
469         TCU_THROW(NotSupportedError, "Ray Tracing pipelines not supported");
470 
471     const auto &asFeatures = context.getAccelerationStructureFeatures();
472     if (!asFeatures.accelerationStructure)
473         TCU_FAIL("VK_KHR_acceleration_structure supported without accelerationStructure support");
474 }
475 
checkSupport(Context & context) const476 void DataSpillTestCase::checkSupport(Context &context) const
477 {
478     // General checks first.
479     commonCheckSupport(context);
480 
481     const auto &features          = context.getDeviceFeatures();
482     const auto &featuresStorage16 = context.get16BitStorageFeatures();
483     const auto &featuresF16I8     = context.getShaderFloat16Int8Features();
484     const auto &featuresStorage8  = context.get8BitStorageFeatures();
485 
486     if (m_params.dataType == DataType::INT64 || m_params.dataType == DataType::UINT64)
487     {
488         if (!features.shaderInt64)
489             TCU_THROW(NotSupportedError, "64-bit integers not supported");
490     }
491     else if (m_params.dataType == DataType::INT16 || m_params.dataType == DataType::UINT16)
492     {
493         context.requireDeviceFunctionality("VK_KHR_16bit_storage");
494 
495         if (!features.shaderInt16)
496             TCU_THROW(NotSupportedError, "16-bit integers not supported");
497 
498         if (!featuresStorage16.storageBuffer16BitAccess)
499             TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
500     }
501     else if (m_params.dataType == DataType::INT8 || m_params.dataType == DataType::UINT8)
502     {
503         context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
504         context.requireDeviceFunctionality("VK_KHR_8bit_storage");
505 
506         if (!featuresF16I8.shaderInt8)
507             TCU_THROW(NotSupportedError, "8-bit integers not supported");
508 
509         if (!featuresStorage8.storageBuffer8BitAccess)
510             TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
511     }
512     else if (m_params.dataType == DataType::FLOAT64)
513     {
514         if (!features.shaderFloat64)
515             TCU_THROW(NotSupportedError, "64-bit floats not supported");
516     }
517     else if (m_params.dataType == DataType::FLOAT16)
518     {
519         context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
520         context.requireDeviceFunctionality("VK_KHR_16bit_storage");
521 
522         if (!featuresF16I8.shaderFloat16)
523             TCU_THROW(NotSupportedError, "16-bit floats not supported");
524 
525         if (!featuresStorage16.storageBuffer16BitAccess)
526             TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
527     }
528     else if (samplersNeeded(m_params.dataType))
529     {
530         context.requireDeviceFunctionality("VK_EXT_descriptor_indexing");
531         const auto indexingFeatures = context.getDescriptorIndexingFeatures();
532         if (!indexingFeatures.shaderSampledImageArrayNonUniformIndexing)
533             TCU_THROW(NotSupportedError, "No support for non-uniform sampled image arrays");
534     }
535 }
536 
initPrograms(vk::SourceCollections & programCollection) const537 void DataSpillTestCase::initPrograms(vk::SourceCollections &programCollection) const
538 {
539     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
540     const vk::SpirVAsmBuildOptions spvBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, true);
541 
542     std::ostringstream spvTemplateStream;
543 
544     // This SPIR-V template will be used to generate shaders for different
545     // stages (raygen, callable, etc). The basic mechanism uses 3 SSBOs: one
546     // used strictly as an input, one to write the check result, and one to
547     // verify the shader call has taken place. The latter two SSBOs contain just
548     // a single uint, but the input SSBO typically contains other type of data
549     // that will be filled from the test instance with predetermined values. The
550     // shader will expect this data to have specific values that can be combined
551     // some way to give an expected result (e.g. by adding the 4 components if
552     // it's a vec4). This result will be used in the shader call to make sure
553     // input values are read *before* the call. After the shader call has taken
554     // place, the shader will attempt to read the input buffer again and verify
555     // the value is still correct and matches the previous one. If the result
556     // matches, it will write a confirmation value in the check buffer. In the
557     // mean time, the callee will write a confirmation value in the callee
558     // buffer to verify the shader call took place.
559     //
560     // Some test variants use samplers, images or sampled images. These need
561     // additional bindings of different types and the interesting value is
562     // typically placed in the image instead of the input buffer, while the
563     // input buffer is used for sampling coordinates instead.
564     //
565     // Some important SPIR-V template variables:
566     //
567     // - INPUT_BUFFER_VALUE_TYPE will contain the type of input buffer data.
568     // - CALC_ZERO_FOR_CALLABLE is expected to contain instructions that will
569     //   calculate a value of zero to be used in the shader call instruction.
570     //   This value should be derived from the input data.
571     // - CALL_STATEMENTS will contain the shader call instructions.
572     // - CALC_EQUAL_STATEMENT is expected to contain instructions that will
573     //   set %equal to true as a %bool if the before- and after- data match.
574     //
575     // - %input_val_ptr contains the pointer to the input value.
576     // - %input_val_before contains the value read before the call.
577     // - %input_val_after contains the value read after the call.
578 
579     spvTemplateStream
580         << "                                  OpCapability RayTracingKHR\n"
581         << "${EXTRA_CAPABILITIES}"
582         << "                                  OpExtension \"SPV_KHR_ray_tracing\"\n"
583         << "${EXTRA_EXTENSIONS}"
584         << "                                  OpMemoryModel Logical GLSL450\n"
585         << "                                  OpEntryPoint ${ENTRY_POINT} %main \"main\" %topLevelAS %calleeBuffer "
586            "%outputBuffer %inputBuffer${MAIN_INTERFACE_EXTRAS}\n"
587         << "${INTERFACE_DECORATIONS}"
588         << "                                  OpMemberDecorate %InputBlock 0 Offset 0\n"
589         << "                                  OpDecorate %InputBlock Block\n"
590         << "                                  OpDecorate %inputBuffer DescriptorSet 0\n"
591         << "                                  OpDecorate %inputBuffer Binding 3\n"
592         << "                                  OpMemberDecorate %OutputBlock 0 Offset 0\n"
593         << "                                  OpDecorate %OutputBlock Block\n"
594         << "                                  OpDecorate %outputBuffer DescriptorSet 0\n"
595         << "                                  OpDecorate %outputBuffer Binding 2\n"
596         << "                                  OpMemberDecorate %CalleeBlock 0 Offset 0\n"
597         << "                                  OpDecorate %CalleeBlock Block\n"
598         << "                                  OpDecorate %calleeBuffer DescriptorSet 0\n"
599         << "                                  OpDecorate %calleeBuffer Binding 1\n"
600         << "                                  OpDecorate %topLevelAS DescriptorSet 0\n"
601         << "                                  OpDecorate %topLevelAS Binding 0\n"
602         << "${EXTRA_BINDINGS}"
603         << "                          %void = OpTypeVoid\n"
604         << "                     %void_func = OpTypeFunction %void\n"
605         << "                           %int = OpTypeInt 32 1\n"
606         << "                          %uint = OpTypeInt 32 0\n"
607         << "                         %int_0 = OpConstant %int 0\n"
608         << "                        %uint_0 = OpConstant %uint 0\n"
609         << "                        %uint_1 = OpConstant %uint 1\n"
610         << "                        %uint_2 = OpConstant %uint 2\n"
611         << "                        %uint_3 = OpConstant %uint 3\n"
612         << "                        %uint_4 = OpConstant %uint 4\n"
613         << "                        %uint_5 = OpConstant %uint 5\n"
614         << "                      %uint_255 = OpConstant %uint 255\n"
615         << "                          %bool = OpTypeBool\n"
616         << "                         %float = OpTypeFloat 32\n"
617         << "                       %float_0 = OpConstant %float 0\n"
618         << "                       %float_1 = OpConstant %float 1\n"
619         << "                       %float_9 = OpConstant %float 9\n"
620         << "                     %float_0_5 = OpConstant %float 0.5\n"
621         << "                      %float_n1 = OpConstant %float -1\n"
622         << "                       %v3float = OpTypeVector %float 3\n"
623         << "                  %origin_const = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0\n"
624         << "               %direction_const = OpConstantComposite %v3float %float_0 %float_0 %float_n1\n"
625         << "${EXTRA_TYPES_AND_CONSTANTS}"
626         << "                 %data_func_ptr = OpTypePointer Function ${INPUT_BUFFER_VALUE_TYPE}\n"
627         << "${INTERFACE_TYPES_AND_VARIABLES}"
628         << "                    %InputBlock = OpTypeStruct ${INPUT_BUFFER_VALUE_TYPE}\n"
629         << " %_ptr_StorageBuffer_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
630         << "                   %inputBuffer = OpVariable %_ptr_StorageBuffer_InputBlock StorageBuffer\n"
631         << "        %data_storagebuffer_ptr = OpTypePointer StorageBuffer ${INPUT_BUFFER_VALUE_TYPE}\n"
632         << "                   %OutputBlock = OpTypeStruct %uint\n"
633         << "%_ptr_StorageBuffer_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
634         << "                  %outputBuffer = OpVariable %_ptr_StorageBuffer_OutputBlock StorageBuffer\n"
635         << "       %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint\n"
636         << "                   %CalleeBlock = OpTypeStruct %uint\n"
637         << "%_ptr_StorageBuffer_CalleeBlock = OpTypePointer StorageBuffer %CalleeBlock\n"
638         << "                  %calleeBuffer = OpVariable %_ptr_StorageBuffer_CalleeBlock StorageBuffer\n"
639         << "                       %as_type = OpTypeAccelerationStructureKHR\n"
640         << "        %as_uniformconstant_ptr = OpTypePointer UniformConstant %as_type\n"
641         << "                    %topLevelAS = OpVariable %as_uniformconstant_ptr UniformConstant\n"
642         << "${EXTRA_BINDING_VARIABLES}"
643         << "                          %main = OpFunction %void None %void_func\n"
644         << "                    %main_label = OpLabel\n"
645         << "${EXTRA_FUNCTION_VARIABLES}"
646         << "                 %input_val_ptr = OpAccessChain %data_storagebuffer_ptr %inputBuffer %int_0\n"
647         << "                %output_val_ptr = OpAccessChain %_ptr_StorageBuffer_uint %outputBuffer %int_0\n"
648         // Note we use Volatile to load the input buffer value before and after the call statements.
649         << "              %input_val_before = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
650         << "${CALC_ZERO_FOR_CALLABLE}"
651         << "${CALL_STATEMENTS}"
652         << "               %input_val_after = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
653         << "${CALC_EQUAL_STATEMENT}"
654         << "                    %output_val = OpSelect %uint %equal %uint_1 %uint_0\n"
655         << "                                  OpStore %output_val_ptr %output_val\n"
656         << "                                  OpReturn\n"
657         << "                                  OpFunctionEnd\n";
658 
659     const tcu::StringTemplate spvTemplate(spvTemplateStream.str());
660 
661     std::map<std::string, std::string> subs;
662     std::string componentTypeName;
663     std::string opEqual;
664     const int numComponents     = static_cast<int>(m_params.vectorType);
665     const auto isArray          = (numComponents > static_cast<int>(VectorType::V4));
666     const auto numComponentsStr = de::toString(numComponents);
667 
668     subs["EXTRA_CAPABILITIES"]        = "";
669     subs["EXTRA_EXTENSIONS"]          = "";
670     subs["EXTRA_TYPES_AND_CONSTANTS"] = "";
671     subs["EXTRA_FUNCTION_VARIABLES"]  = "";
672     subs["EXTRA_BINDINGS"]            = "";
673     subs["EXTRA_BINDING_VARIABLES"]   = "";
674     subs["EXTRA_FUNCTIONS"]           = "";
675 
676     // Take into account some of these substitutions will be updated after the if-block.
677 
678     if (m_params.dataType == DataType::INT32)
679     {
680         componentTypeName = "int";
681 
682         subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
683         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                        %int_37 = OpConstant %int 37\n";
684         subs["CALC_ZERO_FOR_CALLABLE"] = "                      %zero_int = OpISub %int %input_val_before %int_37\n"
685                                          "             %zero_for_callable = OpBitcast %uint %zero_int\n";
686     }
687     else if (m_params.dataType == DataType::UINT32)
688     {
689         componentTypeName = "uint";
690 
691         subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
692         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                       %uint_37 = OpConstant %uint 37\n";
693         subs["CALC_ZERO_FOR_CALLABLE"] = "             %zero_for_callable = OpISub %uint %input_val_before %uint_37\n";
694     }
695     else if (m_params.dataType == DataType::INT64)
696     {
697         componentTypeName = "long";
698 
699         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int64\n";
700         subs["INPUT_BUFFER_VALUE_TYPE"] = "%long";
701         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                          %long = OpTypeInt 64 1\n"
702                                              "                       %long_37 = OpConstant %long 37\n";
703         subs["CALC_ZERO_FOR_CALLABLE"] = "                     %zero_long = OpISub %long %input_val_before %long_37\n"
704                                          "             %zero_for_callable = OpSConvert %uint %zero_long\n";
705     }
706     else if (m_params.dataType == DataType::UINT64)
707     {
708         componentTypeName = "ulong";
709 
710         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int64\n";
711         subs["INPUT_BUFFER_VALUE_TYPE"] = "%ulong";
712         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                         %ulong = OpTypeInt 64 0\n"
713                                              "                      %ulong_37 = OpConstant %ulong 37\n";
714         subs["CALC_ZERO_FOR_CALLABLE"] = "                    %zero_ulong = OpISub %ulong %input_val_before %ulong_37\n"
715                                          "             %zero_for_callable = OpUConvert %uint %zero_ulong\n";
716     }
717     else if (m_params.dataType == DataType::INT16)
718     {
719         componentTypeName = "short";
720 
721         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int16\n"
722                                       "                                  OpCapability StorageBuffer16BitAccess\n";
723         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_KHR_16bit_storage\"\n";
724         subs["INPUT_BUFFER_VALUE_TYPE"] = "%short";
725         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                         %short = OpTypeInt 16 1\n"
726                                              "                      %short_37 = OpConstant %short 37\n";
727         subs["CALC_ZERO_FOR_CALLABLE"] = "                    %zero_short = OpISub %short %input_val_before %short_37\n"
728                                          "             %zero_for_callable = OpSConvert %uint %zero_short\n";
729     }
730     else if (m_params.dataType == DataType::UINT16)
731     {
732         componentTypeName = "ushort";
733 
734         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int16\n"
735                                       "                                  OpCapability StorageBuffer16BitAccess\n";
736         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_KHR_16bit_storage\"\n";
737         subs["INPUT_BUFFER_VALUE_TYPE"] = "%ushort";
738         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                        %ushort = OpTypeInt 16 0\n"
739                                              "                     %ushort_37 = OpConstant %ushort 37\n";
740         subs["CALC_ZERO_FOR_CALLABLE"] =
741             "                   %zero_ushort = OpISub %ushort %input_val_before %ushort_37\n"
742             "             %zero_for_callable = OpUConvert %uint %zero_ushort\n";
743     }
744     else if (m_params.dataType == DataType::INT8)
745     {
746         componentTypeName = "char";
747 
748         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int8\n"
749                                       "                                  OpCapability StorageBuffer8BitAccess\n";
750         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_KHR_8bit_storage\"\n";
751         subs["INPUT_BUFFER_VALUE_TYPE"] = "%char";
752         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                          %char = OpTypeInt 8 1\n"
753                                              "                       %char_37 = OpConstant %char 37\n";
754         subs["CALC_ZERO_FOR_CALLABLE"] = "                     %zero_char = OpISub %char %input_val_before %char_37\n"
755                                          "             %zero_for_callable = OpSConvert %uint %zero_char\n";
756     }
757     else if (m_params.dataType == DataType::UINT8)
758     {
759         componentTypeName = "uchar";
760 
761         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Int8\n"
762                                       "                                  OpCapability StorageBuffer8BitAccess\n";
763         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_KHR_8bit_storage\"\n";
764         subs["INPUT_BUFFER_VALUE_TYPE"] = "%uchar";
765         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                         %uchar = OpTypeInt 8 0\n"
766                                              "                      %uchar_37 = OpConstant %uchar 37\n";
767         subs["CALC_ZERO_FOR_CALLABLE"] = "                    %zero_uchar = OpISub %uchar %input_val_before %uchar_37\n"
768                                          "             %zero_for_callable = OpUConvert %uint %zero_uchar\n";
769     }
770     else if (m_params.dataType == DataType::FLOAT32)
771     {
772         componentTypeName = "float";
773 
774         subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
775         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                      %float_37 = OpConstant %float 37\n";
776         subs["CALC_ZERO_FOR_CALLABLE"] = "                    %zero_float = OpFSub %float %input_val_before %float_37\n"
777                                          "             %zero_for_callable = OpConvertFToU %uint %zero_float\n";
778     }
779     else if (m_params.dataType == DataType::FLOAT64)
780     {
781         componentTypeName = "double";
782 
783         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Float64\n";
784         subs["INPUT_BUFFER_VALUE_TYPE"] = "%double";
785         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                        %double = OpTypeFloat 64\n"
786                                              "                     %double_37 = OpConstant %double 37\n";
787         subs["CALC_ZERO_FOR_CALLABLE"] =
788             "                   %zero_double = OpFSub %double %input_val_before %double_37\n"
789             "             %zero_for_callable = OpConvertFToU %uint %zero_double\n";
790     }
791     else if (m_params.dataType == DataType::FLOAT16)
792     {
793         componentTypeName = "half";
794 
795         subs["EXTRA_CAPABILITIES"] += "                                  OpCapability Float16\n"
796                                       "                                  OpCapability StorageBuffer16BitAccess\n";
797         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_KHR_16bit_storage\"\n";
798         subs["INPUT_BUFFER_VALUE_TYPE"] = "%half";
799         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                          %half = OpTypeFloat 16\n"
800                                              "                       %half_37 = OpConstant %half 37\n";
801         subs["CALC_ZERO_FOR_CALLABLE"] = "                     %zero_half = OpFSub %half %input_val_before %half_37\n"
802                                          "             %zero_for_callable = OpConvertFToU %uint %zero_half\n";
803     }
804     else if (m_params.dataType == DataType::STRUCT)
805     {
806         componentTypeName = "InputStruct";
807 
808         subs["INPUT_BUFFER_VALUE_TYPE"] = "%InputStruct";
809         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                   %InputStruct = OpTypeStruct %uint %float\n"
810                                              "                      %float_37 = OpConstant %float 37\n"
811                                              "            %uint_part_ptr_type = OpTypePointer StorageBuffer %uint\n"
812                                              "           %float_part_ptr_type = OpTypePointer StorageBuffer %float\n"
813                                              "       %uint_part_func_ptr_type = OpTypePointer Function %uint\n"
814                                              "      %float_part_func_ptr_type = OpTypePointer Function %float\n"
815                                              "    %input_struct_func_ptr_type = OpTypePointer Function %InputStruct\n";
816         subs["INTERFACE_DECORATIONS"] = "                                  OpMemberDecorate %InputStruct 0 Offset 0\n"
817                                         "                                  OpMemberDecorate %InputStruct 1 Offset 4\n";
818 
        // Sum struct members, then subtract the constant and convert to uint.
820         subs["CALC_ZERO_FOR_CALLABLE"] =
821             "                 %uint_part_ptr = OpAccessChain %uint_part_ptr_type %input_val_ptr %uint_0\n"
822             "                %float_part_ptr = OpAccessChain %float_part_ptr_type %input_val_ptr %uint_1\n"
823             "                     %uint_part = OpLoad %uint %uint_part_ptr\n"
824             "                    %float_part = OpLoad %float %float_part_ptr\n"
825             "                 %uint_as_float = OpConvertUToF %float %uint_part\n"
826             "                    %member_sum = OpFAdd %float %float_part %uint_as_float\n"
827             "                    %zero_float = OpFSub %float %member_sum %float_37\n"
828             "             %zero_for_callable = OpConvertFToU %uint %zero_float\n";
829     }
830     else if (samplersNeeded(m_params.dataType))
831     {
832         // These tests will use additional bindings as arrays of 2 elements:
833         // - 1 array of samplers.
834         // - 1 array of images.
835         // - 1 array of combined image samplers.
836         // Input values are typically used as texture coordinates (normally zeros)
837         // Pixels will contain the expected values instead of them being in the input buffer.
838 
839         subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
840         subs["EXTRA_CAPABILITIES"] +=
841             "                                  OpCapability SampledImageArrayNonUniformIndexing\n";
842         subs["EXTRA_EXTENSIONS"] += "                                  OpExtension \"SPV_EXT_descriptor_indexing\"\n";
843         subs["MAIN_INTERFACE_EXTRAS"] += " %sampledTexture %textureSampler %combinedImageSampler";
844         subs["EXTRA_BINDINGS"] += "                                  OpDecorate %sampledTexture DescriptorSet 0\n"
845                                   "                                  OpDecorate %sampledTexture Binding 4\n"
846                                   "                                  OpDecorate %textureSampler DescriptorSet 0\n"
847                                   "                                  OpDecorate %textureSampler Binding 5\n"
848                                   "                                  OpDecorate %combinedImageSampler DescriptorSet 0\n"
849                                   "                                  OpDecorate %combinedImageSampler Binding 6\n";
850         subs["EXTRA_TYPES_AND_CONSTANTS"] +=
851             "                       %uint_37 = OpConstant %uint 37\n"
852             "                        %v4uint = OpTypeVector %uint 4\n"
853             "                       %v2float = OpTypeVector %float 2\n"
854             "                    %image_type = OpTypeImage %uint 2D 0 0 0 1 Unknown\n"
855             "              %image_array_type = OpTypeArray %image_type %uint_2\n"
856             "  %image_array_type_uniform_ptr = OpTypePointer UniformConstant %image_array_type\n"
857             "        %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
858             "                  %sampler_type = OpTypeSampler\n"
859             "            %sampler_array_type = OpTypeArray %sampler_type %uint_2\n"
860             "%sampler_array_type_uniform_ptr = OpTypePointer UniformConstant %sampler_array_type\n"
861             "      %sampler_type_uniform_ptr = OpTypePointer UniformConstant %sampler_type\n"
862             "            %sampled_image_type = OpTypeSampledImage %image_type\n"
863             "      %sampled_image_array_type = OpTypeArray %sampled_image_type %uint_2\n"
864             "%sampled_image_array_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_array_type\n"
865             "%sampled_image_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_type\n";
866         subs["EXTRA_BINDING_VARIABLES"] +=
867             "                %sampledTexture = OpVariable %image_array_type_uniform_ptr UniformConstant\n"
868             "                %textureSampler = OpVariable %sampler_array_type_uniform_ptr UniformConstant\n"
869             "          %combinedImageSampler = OpVariable %sampled_image_array_type_uniform_ptr UniformConstant\n";
870 
871         if (m_params.dataType == DataType::IMAGE || m_params.dataType == DataType::SAMPLER)
872         {
873             // Use the first sampler and sample from the first image.
874             subs["CALC_ZERO_FOR_CALLABLE"] +=
875                 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
876                 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
877                 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
878                 "%image_0 = OpLoad %image_type %image_0_ptr\n"
879                 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
880                 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
881                 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
882                 "%float_0\n"
883                 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
884                 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
885         }
886         else if (m_params.dataType == DataType::SAMPLED_IMAGE)
887         {
888             // Use the first combined image sampler.
889             subs["CALC_ZERO_FOR_CALLABLE"] +=
890                 "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
891                 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
892                 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
893                 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
894                 "%float_0\n"
895                 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
896                 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
897         }
898         else if (m_params.dataType == DataType::PTR_IMAGE)
899         {
900             // We attempt to create the second pointer before the call.
901             subs["CALC_ZERO_FOR_CALLABLE"] +=
902                 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
903                 "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
904                 "%image_0 = OpLoad %image_type %image_0_ptr\n"
905                 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
906                 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
907                 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
908                 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
909                 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
910                 "%float_0\n"
911                 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
912                 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
913         }
914         else if (m_params.dataType == DataType::PTR_SAMPLER)
915         {
916             // We attempt to create the second pointer before the call.
917             subs["CALC_ZERO_FOR_CALLABLE"] +=
918                 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
919                 "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
920                 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
921                 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
922                 "%image_0 = OpLoad %image_type %image_0_ptr\n"
923                 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
924                 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
925                 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
926                 "%float_0\n"
927                 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
928                 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
929         }
930         else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
931         {
932             // We attempt to create the second pointer before the call.
933             subs["CALC_ZERO_FOR_CALLABLE"] +=
934                 "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
935                 "%sampled_image_1_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_1\n"
936                 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
937                 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
938                 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
939                 "%float_0\n"
940                 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
941                 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
942         }
943         else
944         {
945             DE_ASSERT(false);
946         }
947     }
948     else if (storageImageNeeded(m_params.dataType))
949     {
950         subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
951         subs["MAIN_INTERFACE_EXTRAS"] += " %storageImage";
952         subs["EXTRA_BINDINGS"] += "                                  OpDecorate %storageImage DescriptorSet 0\n"
953                                   "                                  OpDecorate %storageImage Binding 4\n";
954         subs["EXTRA_TYPES_AND_CONSTANTS"] +=
955             "                       %uint_37 = OpConstant %uint 37\n"
956             "                         %v2int = OpTypeVector %int 2\n"
957             "                    %image_type = OpTypeImage %uint 2D 0 0 0 2 R32ui\n"
958             "        %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
959             "                  %uint_img_ptr = OpTypePointer Image %uint\n";
960         subs["EXTRA_BINDING_VARIABLES"] +=
961             "                  %storageImage = OpVariable %image_type_uniform_ptr UniformConstant\n";
962 
963         // Load value from the image, expecting it to be 37 and swapping it with 5.
964         subs["CALC_ZERO_FOR_CALLABLE"] +=
965             "%coords = OpCompositeConstruct %v2int %input_val_before %input_val_before\n"
966             "%texel_ptr = OpImageTexelPointer %uint_img_ptr %storageImage %coords %uint_0\n"
967             "%texel_value = OpAtomicCompareExchange %uint %texel_ptr %uint_1 %uint_0 %uint_0 %uint_5 %uint_37\n"
968             "%zero_for_callable = OpISub %uint %texel_value %uint_37\n";
969     }
970     else if (m_params.dataType == DataType::OP_NULL)
971     {
972         subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
973         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                       %uint_37 = OpConstant %uint 37\n"
974                                              "                 %constant_null = OpConstantNull %uint\n";
975 
976         // Create a local copy of the null constant global object to work with it.
977         subs["CALC_ZERO_FOR_CALLABLE"] +=
978             "%constant_null_copy = OpCopyObject %uint %constant_null\n"
979             "%is_37_before = OpIEqual %bool %input_val_before %uint_37\n"
980             "%zero_for_callable = OpSelect %uint %is_37_before %constant_null_copy %uint_5\n";
981     }
982     else if (m_params.dataType == DataType::OP_UNDEF)
983     {
984         subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
985         subs["EXTRA_TYPES_AND_CONSTANTS"] += "                       %uint_37 = OpConstant %uint 37\n";
986 
987         // Extract an undef value and write it to the output buffer to make sure it's used before the call. The value will be overwritten later.
988         subs["CALC_ZERO_FOR_CALLABLE"] += "%undef_var = OpUndef %uint\n"
989                                           "%undef_val_before = OpCopyObject %uint %undef_var\n"
990                                           "OpStore %output_val_ptr %undef_val_before Volatile\n"
991                                           "%zero_for_callable = OpISub %uint %uint_37 %input_val_before\n";
992     }
993     else
994     {
995         DE_ASSERT(false);
996     }
997 
998     // Comparison statement for data before and after the call.
999     switch (m_params.dataType)
1000     {
1001     case DataType::INT32:
1002     case DataType::UINT32:
1003     case DataType::INT64:
1004     case DataType::UINT64:
1005     case DataType::INT16:
1006     case DataType::UINT16:
1007     case DataType::INT8:
1008     case DataType::UINT8:
1009         opEqual = "OpIEqual";
1010         break;
1011     case DataType::FLOAT32:
1012     case DataType::FLOAT64:
1013     case DataType::FLOAT16:
1014         opEqual = "OpFOrdEqual";
1015         break;
1016     case DataType::STRUCT:
1017     case DataType::IMAGE:
1018     case DataType::SAMPLER:
1019     case DataType::SAMPLED_IMAGE:
1020     case DataType::PTR_IMAGE:
1021     case DataType::PTR_SAMPLER:
1022     case DataType::PTR_SAMPLED_IMAGE:
1023     case DataType::PTR_TEXEL:
1024     case DataType::OP_NULL:
1025     case DataType::OP_UNDEF:
        // These need special code for the comparison.
1027         opEqual = "INVALID";
1028         break;
1029     default:
1030         DE_ASSERT(false);
1031         break;
1032     }
1033 
1034     if (m_params.dataType == DataType::STRUCT)
1035     {
1036         // We need to store the before and after values in a variable in order to be able to access each member individually without accessing the StorageBuffer again.
1037         subs["EXTRA_FUNCTION_VARIABLES"] =
1038             "         %input_val_func_before = OpVariable %input_struct_func_ptr_type Function\n"
1039             "          %input_val_func_after = OpVariable %input_struct_func_ptr_type Function\n";
1040         subs["CALC_EQUAL_STATEMENT"] =
1041             "                                  OpStore %input_val_func_before %input_val_before\n"
1042             "                                  OpStore %input_val_func_after %input_val_after\n"
1043             "     %uint_part_func_before_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_before %uint_0\n"
1044             "    %float_part_func_before_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_before %uint_1\n"
1045             "      %uint_part_func_after_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_after %uint_0\n"
1046             "     %float_part_func_after_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_after %uint_1\n"
1047             "              %uint_part_before = OpLoad %uint %uint_part_func_before_ptr\n"
1048             "             %float_part_before = OpLoad %float %float_part_func_before_ptr\n"
1049             "               %uint_part_after = OpLoad %uint %uint_part_func_after_ptr\n"
1050             "              %float_part_after = OpLoad %float %float_part_func_after_ptr\n"
1051             "                    %uint_equal = OpIEqual %bool %uint_part_before %uint_part_after\n"
1052             "                   %float_equal = OpFOrdEqual %bool %float_part_before %float_part_after\n"
1053             "                         %equal = OpLogicalAnd %bool %uint_equal %float_equal\n";
1054     }
1055     else if (m_params.dataType == DataType::IMAGE)
1056     {
1057         // Use the same image and the second sampler with different coordinates (actually the same).
1058         subs["CALC_EQUAL_STATEMENT"] +=
1059             "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
1060             "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1061             "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1062             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1063             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1064             "%float_0\n"
1065             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1066             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1067     }
1068     else if (m_params.dataType == DataType::SAMPLER)
1069     {
1070         // Use the same sampler and sample from the second image with different coordinates (but actually the same).
1071         subs["CALC_EQUAL_STATEMENT"] +=
1072             "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
1073             "%image_1 = OpLoad %image_type %image_1_ptr\n"
1074             "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1075             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1076             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1077             "%float_0\n"
1078             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1079             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1080     }
1081     else if (m_params.dataType == DataType::SAMPLED_IMAGE)
1082     {
1083         // Reuse the same combined image sampler with different coordinates (actually the same).
1084         subs["CALC_EQUAL_STATEMENT"] +=
1085             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1086             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_1 Lod|ZeroExtend "
1087             "%float_0\n"
1088             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1089             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1090     }
1091     else if (m_params.dataType == DataType::PTR_IMAGE)
1092     {
1093         // We attempt to use the second pointer only after the call.
1094         subs["CALC_EQUAL_STATEMENT"] +=
1095             "%image_1 = OpLoad %image_type %image_1_ptr\n"
1096             "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1097             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1098             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1099             "%float_0\n"
1100             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1101             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1102     }
1103     else if (m_params.dataType == DataType::PTR_SAMPLER)
1104     {
1105         // We attempt to use the second pointer only after the call.
1106         subs["CALC_EQUAL_STATEMENT"] +=
1107             "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1108             "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1109             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1110             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1111             "%float_0\n"
1112             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1113             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1114     }
1115     else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
1116     {
1117         // We attempt to use the second pointer only after the call.
1118         subs["CALC_EQUAL_STATEMENT"] +=
1119             "%sampled_image_1 = OpLoad %sampled_image_type %sampled_image_1_ptr\n"
1120             "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1121             "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1122             "%float_0\n"
1123             "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1124             "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1125     }
1126     else if (m_params.dataType == DataType::PTR_TEXEL)
1127     {
1128         // Check value 5 was stored properly.
1129         subs["CALC_EQUAL_STATEMENT"] += "%stored_val = OpAtomicLoad %uint %texel_ptr %uint_1 %uint_0\n"
1130                                         "%equal = OpIEqual %bool %stored_val %uint_5\n";
1131     }
1132     else if (m_params.dataType == DataType::OP_NULL)
1133     {
1134         // Reuse the null constant after the call.
1135         subs["CALC_EQUAL_STATEMENT"] += "%is_37_after = OpIEqual %bool %input_val_after %uint_37\n"
1136                                         "%writeback_val = OpSelect %uint %is_37_after %constant_null_copy %uint_5\n"
1137                                         "OpStore %input_val_ptr %writeback_val Volatile\n"
1138                                         "%readback_val = OpLoad %uint %input_val_ptr Volatile\n"
1139                                         "%equal = OpIEqual %bool %readback_val %uint_0\n";
1140     }
1141     else if (m_params.dataType == DataType::OP_UNDEF)
1142     {
1143         // Extract another undef value and write it to the input buffer. It will not be checked later.
1144         subs["CALC_EQUAL_STATEMENT"] += "%undef_val_after = OpCopyObject %uint %undef_var\n"
1145                                         "OpStore %input_val_ptr %undef_val_after Volatile\n"
1146                                         "%equal = OpIEqual %bool %input_val_after %input_val_before\n";
1147     }
1148     else
1149     {
1150         subs["CALC_EQUAL_STATEMENT"] +=
1151             "                         %equal = " + opEqual + " %bool %input_val_before %input_val_after\n";
1152     }
1153 
1154     // Modifications for vectors and arrays.
1155     if (numComponents > 1)
1156     {
1157         const std::string vectorTypeName    = "v" + numComponentsStr + componentTypeName;
1158         const std::string opType            = (isArray ? "OpTypeArray" : "OpTypeVector");
1159         const std::string componentCountStr = (isArray ? ("%uint_" + numComponentsStr) : numComponentsStr);
1160 
1161         // Some extra types are needed.
1162         if (!(m_params.dataType == DataType::FLOAT32 && m_params.vectorType == VectorType::V3))
1163         {
1164             // Note: v3float is already defined in the shader by default.
1165             subs["EXTRA_TYPES_AND_CONSTANTS"] +=
1166                 "%" + vectorTypeName + " = " + opType + " %" + componentTypeName + " " + componentCountStr + "\n";
1167         }
1168         subs["EXTRA_TYPES_AND_CONSTANTS"] +=
1169             "%v" + numComponentsStr + "bool = " + opType + " %bool " + componentCountStr + "\n";
1170         subs["EXTRA_TYPES_AND_CONSTANTS"] += "%comp_ptr = OpTypePointer StorageBuffer %" + componentTypeName + "\n";
1171 
1172         // The input value in the buffer has a different type.
1173         subs["INPUT_BUFFER_VALUE_TYPE"] = "%" + vectorTypeName;
1174 
1175         // Overwrite the way we calculate the zero used in the call.
1176 
1177         // Proper operations for adding, substracting and converting components.
1178         std::string opAdd;
1179         std::string opSub;
1180         std::string opConvert;
1181 
1182         switch (m_params.dataType)
1183         {
1184         case DataType::INT32:
1185         case DataType::UINT32:
1186         case DataType::INT64:
1187         case DataType::UINT64:
1188         case DataType::INT16:
1189         case DataType::UINT16:
1190         case DataType::INT8:
1191         case DataType::UINT8:
1192             opAdd = "OpIAdd";
1193             opSub = "OpISub";
1194             break;
1195         case DataType::FLOAT32:
1196         case DataType::FLOAT64:
1197         case DataType::FLOAT16:
1198             opAdd = "OpFAdd";
1199             opSub = "OpFSub";
1200             break;
1201         default:
1202             DE_ASSERT(false);
1203             break;
1204         }
1205 
1206         switch (m_params.dataType)
1207         {
1208         case DataType::UINT32:
1209             opConvert = "OpCopyObject";
1210             break;
1211         case DataType::INT32:
1212             opConvert = "OpBitcast";
1213             break;
1214         case DataType::INT64:
1215         case DataType::INT16:
1216         case DataType::INT8:
1217             opConvert = "OpSConvert";
1218             break;
1219         case DataType::UINT64:
1220         case DataType::UINT16:
1221         case DataType::UINT8:
1222             opConvert = "OpUConvert";
1223             break;
1224         case DataType::FLOAT32:
1225         case DataType::FLOAT64:
1226         case DataType::FLOAT16:
1227             opConvert = "OpConvertFToU";
1228             break;
1229         default:
1230             DE_ASSERT(false);
1231             break;
1232         }
1233 
1234         std::ostringstream zeroForCallable;
1235 
1236         // Create pointers to components and load components.
1237         for (int i = 0; i < numComponents; ++i)
1238         {
1239             zeroForCallable << "%component_ptr_" << i << " = OpAccessChain %comp_ptr %input_val_ptr %uint_" << i << "\n"
1240                             << "%component_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i
1241                             << "\n";
1242         }
1243 
1244         // Sum components together in %total_sum.
1245         for (int i = 1; i < numComponents; ++i)
1246         {
1247             const std::string previous = ((i == 1) ? "%component_0" : ("%partial_" + de::toString(i - 1)));
1248             const std::string resultName =
1249                 ((i == (numComponents - 1)) ? "%total_sum" : ("%partial_" + de::toString(i)));
1250             zeroForCallable << resultName << " = " << opAdd << " %" << componentTypeName << " %component_" << i << " "
1251                             << previous << "\n";
1252         }
1253 
1254         // Recalculate the zero.
1255         zeroForCallable << "%zero_" << componentTypeName << " = " << opSub << " %" << componentTypeName
1256                         << " %total_sum %" << componentTypeName << "_37\n"
1257                         << "%zero_for_callable = " << opConvert << " %uint %zero_" << componentTypeName << "\n";
1258 
1259         // Finally replace the zero_for_callable statements with the special version for vectors.
1260         subs["CALC_ZERO_FOR_CALLABLE"] = zeroForCallable.str();
1261 
1262         // Rework comparison statements.
1263         if (isArray)
1264         {
1265             // Arrays need to be compared per-component.
1266             std::ostringstream calcEqual;
1267 
1268             for (int i = 0; i < numComponents; ++i)
1269             {
1270                 calcEqual << "%component_after_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i
1271                           << "\n"
1272                           << "%equal_" << i << " = " << opEqual << " %bool %component_" << i << " %component_after_"
1273                           << i << "\n";
1274                 if (i > 0)
1275                     calcEqual << "%and_" << i << " = OpLogicalAnd %bool %equal_" << (i - 1) << " %equal_" << i << "\n";
1276                 if (i == numComponents - 1)
1277                     calcEqual << "%equal = OpCopyObject %bool %and_" << i << "\n";
1278             }
1279 
1280             subs["CALC_EQUAL_STATEMENT"] = calcEqual.str();
1281         }
1282         else
1283         {
1284             // Vectors can be compared using a bool vector and OpAll.
1285             subs["CALC_EQUAL_STATEMENT"] = "                  %equal_vector = " + opEqual + " %v" + numComponentsStr +
1286                                            "bool %input_val_before %input_val_after\n";
1287             subs["CALC_EQUAL_STATEMENT"] += "                         %equal = OpAll %bool %equal_vector\n";
1288         }
1289     }
1290 
1291     if (isArray)
1292     {
1293         // Arrays need an ArrayStride decoration.
1294         std::ostringstream interfaceDecorations;
1295         interfaceDecorations << "OpDecorate %v" << numComponentsStr << componentTypeName << " ArrayStride "
1296                              << getElementSize(m_params.dataType, VectorType::SCALAR) << "\n";
1297         subs["INTERFACE_DECORATIONS"] = interfaceDecorations.str();
1298     }
1299 
1300     const auto inputBlockDecls = getGLSLInputValDecl(m_params.dataType, m_params.vectorType);
1301 
1302     std::ostringstream glslBindings;
1303     glslBindings << inputBlockDecls.first // Additional data types needed.
1304                  << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1305                  << "layout(set = 0, binding = 1) buffer CalleeBlock { uint val; } calleeBuffer;\n"
1306                  << "layout(set = 0, binding = 2) buffer OutputBlock { uint val; } outputBuffer;\n"
1307                  << "layout(set = 0, binding = 3) buffer InputBlock { " << inputBlockDecls.second
1308                  << " } inputBuffer;\n";
1309 
1310     if (samplersNeeded(m_params.dataType))
1311     {
1312         glslBindings << "layout(set = 0, binding = 4) uniform utexture2D sampledTexture[2];\n"
1313                      << "layout(set = 0, binding = 5) uniform sampler textureSampler[2];\n"
1314                      << "layout(set = 0, binding = 6) uniform usampler2D combinedImageSampler[2];\n";
1315     }
1316     else if (storageImageNeeded(m_params.dataType))
1317     {
1318         glslBindings << "layout(set = 0, binding = 4, r32ui) uniform uimage2D storageImage;\n";
1319     }
1320 
1321     const auto glslBindingsStr = glslBindings.str();
1322     const auto glslHeaderStr   = "#version 460 core\n"
1323                                  "#extension GL_EXT_ray_tracing : require\n"
1324                                  "#extension GL_EXT_shader_explicit_arithmetic_types : require\n";
1325 
1326     if (m_params.callType == CallType::TRACE_RAY)
1327     {
1328         subs["ENTRY_POINT"] = "RayGenerationKHR";
1329         subs["MAIN_INTERFACE_EXTRAS"] += " %hitValue";
1330         subs["INTERFACE_DECORATIONS"] += "                                  OpDecorate %hitValue Location 0\n";
1331         subs["INTERFACE_TYPES_AND_VARIABLES"] =
1332             "                   %payload_ptr = OpTypePointer RayPayloadKHR %v3float\n"
1333             "                      %hitValue = OpVariable %payload_ptr RayPayloadKHR\n";
1334         subs["CALL_STATEMENTS"] =
1335             "                      %as_value = OpLoad %as_type %topLevelAS\n"
1336             "                                  OpTraceRayKHR %as_value %uint_0 %uint_255 %zero_for_callable "
1337             "%zero_for_callable %zero_for_callable %origin_const %float_0 %direction_const %float_9 %hitValue\n";
1338 
1339         const auto rgen = spvTemplate.specialize(subs);
1340         programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1341 
1342         std::stringstream chit;
1343         chit << glslHeaderStr << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1344              << "hitAttributeEXT vec3 attribs;\n"
1345              << glslBindingsStr << "void main()\n"
1346              << "{\n"
1347              << "    calleeBuffer.val = 1u;\n"
1348              << "}\n";
1349         programCollection.glslSources.add("chit")
1350             << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
1351     }
1352     else if (m_params.callType == CallType::EXECUTE_CALLABLE)
1353     {
1354         subs["ENTRY_POINT"] = "RayGenerationKHR";
1355         subs["MAIN_INTERFACE_EXTRAS"] += " %callableData";
1356         subs["INTERFACE_DECORATIONS"] += "                                  OpDecorate %callableData Location 0\n";
1357         subs["INTERFACE_TYPES_AND_VARIABLES"] =
1358             "             %callable_data_ptr = OpTypePointer CallableDataKHR %float\n"
1359             "                  %callableData = OpVariable %callable_data_ptr CallableDataKHR\n";
1360         subs["CALL_STATEMENTS"] =
1361             "                                  OpExecuteCallableKHR %zero_for_callable %callableData\n";
1362 
1363         const auto rgen = spvTemplate.specialize(subs);
1364         programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1365 
1366         std::ostringstream call;
1367         call << glslHeaderStr << "layout(location = 0) callableDataInEXT float callableData;\n"
1368              << glslBindingsStr << "void main()\n"
1369              << "{\n"
1370              << "    calleeBuffer.val = 1u;\n"
1371              << "}\n";
1372 
1373         programCollection.glslSources.add("call")
1374             << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
1375     }
1376     else if (m_params.callType == CallType::REPORT_INTERSECTION)
1377     {
1378         subs["ENTRY_POINT"] = "IntersectionKHR";
1379         subs["MAIN_INTERFACE_EXTRAS"] += " %attribs";
1380         subs["INTERFACE_DECORATIONS"] += "";
1381         subs["INTERFACE_TYPES_AND_VARIABLES"] =
1382             "             %hit_attribute_ptr = OpTypePointer HitAttributeKHR %v3float\n"
1383             "                       %attribs = OpVariable %hit_attribute_ptr HitAttributeKHR\n";
1384         subs["CALL_STATEMENTS"] =
1385             "              %intersection_ret = OpReportIntersectionKHR %bool %float_1 %zero_for_callable\n";
1386 
1387         const auto rint = spvTemplate.specialize(subs);
1388         programCollection.spirvAsmSources.add("rint") << rint << spvBuildOptions;
1389 
1390         std::ostringstream rgen;
1391         rgen << glslHeaderStr << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
1392              << glslBindingsStr << "void main()\n"
1393              << "{\n"
1394              << "  traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
1395                 "0);\n"
1396              << "}\n";
1397         programCollection.glslSources.add("rgen")
1398             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
1399 
1400         std::stringstream ahit;
1401         ahit << glslHeaderStr << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1402              << "hitAttributeEXT vec3 attribs;\n"
1403              << glslBindingsStr << "void main()\n"
1404              << "{\n"
1405              << "    calleeBuffer.val = 1u;\n"
1406              << "}\n";
1407         programCollection.glslSources.add("ahit")
1408             << glu::AnyHitSource(updateRayTracingGLSL(ahit.str())) << buildOptions;
1409     }
1410     else
1411     {
1412         DE_ASSERT(false);
1413     }
1414 }
1415 
1416 using v2i32 = tcu::Vector<int32_t, 2>;
1417 using v3i32 = tcu::Vector<int32_t, 3>;
1418 using v4i32 = tcu::Vector<int32_t, 4>;
1419 using a5i32 = std::array<int32_t, 5>;
1420 
1421 using v2u32 = tcu::Vector<uint32_t, 2>;
1422 using v3u32 = tcu::Vector<uint32_t, 3>;
1423 using v4u32 = tcu::Vector<uint32_t, 4>;
1424 using a5u32 = std::array<uint32_t, 5>;
1425 
1426 using v2i64 = tcu::Vector<int64_t, 2>;
1427 using v3i64 = tcu::Vector<int64_t, 3>;
1428 using v4i64 = tcu::Vector<int64_t, 4>;
1429 using a5i64 = std::array<int64_t, 5>;
1430 
1431 using v2u64 = tcu::Vector<uint64_t, 2>;
1432 using v3u64 = tcu::Vector<uint64_t, 3>;
1433 using v4u64 = tcu::Vector<uint64_t, 4>;
1434 using a5u64 = std::array<uint64_t, 5>;
1435 
1436 using v2i16 = tcu::Vector<int16_t, 2>;
1437 using v3i16 = tcu::Vector<int16_t, 3>;
1438 using v4i16 = tcu::Vector<int16_t, 4>;
1439 using a5i16 = std::array<int16_t, 5>;
1440 
1441 using v2u16 = tcu::Vector<uint16_t, 2>;
1442 using v3u16 = tcu::Vector<uint16_t, 3>;
1443 using v4u16 = tcu::Vector<uint16_t, 4>;
1444 using a5u16 = std::array<uint16_t, 5>;
1445 
1446 using v2i8 = tcu::Vector<int8_t, 2>;
1447 using v3i8 = tcu::Vector<int8_t, 3>;
1448 using v4i8 = tcu::Vector<int8_t, 4>;
1449 using a5i8 = std::array<int8_t, 5>;
1450 
1451 using v2u8 = tcu::Vector<uint8_t, 2>;
1452 using v3u8 = tcu::Vector<uint8_t, 3>;
1453 using v4u8 = tcu::Vector<uint8_t, 4>;
1454 using a5u8 = std::array<uint8_t, 5>;
1455 
1456 using v2f32 = tcu::Vector<tcu::Float32, 2>;
1457 using v3f32 = tcu::Vector<tcu::Float32, 3>;
1458 using v4f32 = tcu::Vector<tcu::Float32, 4>;
1459 using a5f32 = std::array<tcu::Float32, 5>;
1460 
1461 using v2f64 = tcu::Vector<tcu::Float64, 2>;
1462 using v3f64 = tcu::Vector<tcu::Float64, 3>;
1463 using v4f64 = tcu::Vector<tcu::Float64, 4>;
1464 using a5f64 = std::array<tcu::Float64, 5>;
1465 
1466 using v2f16 = tcu::Vector<tcu::Float16, 2>;
1467 using v3f16 = tcu::Vector<tcu::Float16, 3>;
1468 using v4f16 = tcu::Vector<tcu::Float16, 4>;
1469 using a5f16 = std::array<tcu::Float16, 5>;
1470 
// The GEN_*_FILL macros below each build one value of DATA_TYPE and copy its bytes to `bufferPtr`, which must be
// a variable in scope at the expansion site. The stored components always sum to 37, the same amount the shader
// subtracts, so the shader ends up computing zero.

// Scalars: the single stored value is 37.
#define GEN_SCALAR_FILL(DATA_TYPE)                                  \
    do                                                              \
    {                                                               \
        const DATA_TYPE scalarValue = static_cast<DATA_TYPE>(37.0); \
        deMemcpy(bufferPtr, &scalarValue, sizeof(scalarValue));     \
    } while (0)

// Two-component vectors: 21 + 16 = 37.
#define GEN_V2_FILL(DATA_TYPE)                                                 \
    do                                                                         \
    {                                                                          \
        using ElemT = DATA_TYPE::Element;                                      \
        const DATA_TYPE vecValue(static_cast<ElemT>(21.0), static_cast<ElemT>(16.0)); \
        deMemcpy(bufferPtr, &vecValue, sizeof(vecValue));                      \
    } while (0)

// Three-component vectors: 11 + 19 + 7 = 37.
#define GEN_V3_FILL(DATA_TYPE)                                                            \
    do                                                                                    \
    {                                                                                     \
        using ElemT = DATA_TYPE::Element;                                                 \
        const DATA_TYPE vecValue(static_cast<ElemT>(11.0), static_cast<ElemT>(19.0),      \
                                 static_cast<ElemT>(7.0));                                \
        deMemcpy(bufferPtr, &vecValue, sizeof(vecValue));                                 \
    } while (0)

// Four-component vectors: 9 + 11 + 3 + 14 = 37.
#define GEN_V4_FILL(DATA_TYPE)                                                            \
    do                                                                                    \
    {                                                                                     \
        using ElemT = DATA_TYPE::Element;                                                 \
        const DATA_TYPE vecValue(static_cast<ElemT>(9.0), static_cast<ElemT>(11.0),       \
                                 static_cast<ElemT>(3.0), static_cast<ElemT>(14.0));      \
        deMemcpy(bufferPtr, &vecValue, sizeof(vecValue));                                 \
    } while (0)

// Five-element arrays: 13 + 6 + 2 + 5 + 11 = 37.
#define GEN_A5_FILL(DATA_TYPE)                                                             \
    do                                                                                     \
    {                                                                                      \
        using ElemT        = DATA_TYPE::value_type;                                        \
        DATA_TYPE arrValue = {{static_cast<ElemT>(13.0), static_cast<ElemT>(6.0),          \
                               static_cast<ElemT>(2.0), static_cast<ElemT>(5.0),           \
                               static_cast<ElemT>(11.0)}};                                 \
        deMemcpy(bufferPtr, arrValue.data(), de::dataSize(arrValue));                      \
    } while (0)
1521 
fillInputBuffer(DataType dataType,VectorType vectorType,void * bufferPtr)1522 void fillInputBuffer(DataType dataType, VectorType vectorType, void *bufferPtr)
1523 {
1524     if (vectorType == VectorType::SCALAR)
1525     {
1526         if (dataType == DataType::INT32)
1527             GEN_SCALAR_FILL(int32_t);
1528         else if (dataType == DataType::UINT32)
1529             GEN_SCALAR_FILL(uint32_t);
1530         else if (dataType == DataType::INT64)
1531             GEN_SCALAR_FILL(int64_t);
1532         else if (dataType == DataType::UINT64)
1533             GEN_SCALAR_FILL(uint64_t);
1534         else if (dataType == DataType::INT16)
1535             GEN_SCALAR_FILL(int16_t);
1536         else if (dataType == DataType::UINT16)
1537             GEN_SCALAR_FILL(uint16_t);
1538         else if (dataType == DataType::INT8)
1539             GEN_SCALAR_FILL(int8_t);
1540         else if (dataType == DataType::UINT8)
1541             GEN_SCALAR_FILL(uint8_t);
1542         else if (dataType == DataType::FLOAT32)
1543             GEN_SCALAR_FILL(tcu::Float32);
1544         else if (dataType == DataType::FLOAT64)
1545             GEN_SCALAR_FILL(tcu::Float64);
1546         else if (dataType == DataType::FLOAT16)
1547             GEN_SCALAR_FILL(tcu::Float16);
1548         else if (dataType == DataType::STRUCT)
1549         {
1550             InputStruct data = {12u, 25.0f};
1551             deMemcpy(bufferPtr, &data, sizeof(data));
1552         }
1553         else if (dataType == DataType::OP_NULL)
1554             GEN_SCALAR_FILL(uint32_t);
1555         else if (dataType == DataType::OP_UNDEF)
1556             GEN_SCALAR_FILL(uint32_t);
1557         else
1558         {
1559             DE_ASSERT(false);
1560         }
1561     }
1562     else if (vectorType == VectorType::V2)
1563     {
1564         if (dataType == DataType::INT32)
1565             GEN_V2_FILL(v2i32);
1566         else if (dataType == DataType::UINT32)
1567             GEN_V2_FILL(v2u32);
1568         else if (dataType == DataType::INT64)
1569             GEN_V2_FILL(v2i64);
1570         else if (dataType == DataType::UINT64)
1571             GEN_V2_FILL(v2u64);
1572         else if (dataType == DataType::INT16)
1573             GEN_V2_FILL(v2i16);
1574         else if (dataType == DataType::UINT16)
1575             GEN_V2_FILL(v2u16);
1576         else if (dataType == DataType::INT8)
1577             GEN_V2_FILL(v2i8);
1578         else if (dataType == DataType::UINT8)
1579             GEN_V2_FILL(v2u8);
1580         else if (dataType == DataType::FLOAT32)
1581             GEN_V2_FILL(v2f32);
1582         else if (dataType == DataType::FLOAT64)
1583             GEN_V2_FILL(v2f64);
1584         else if (dataType == DataType::FLOAT16)
1585             GEN_V2_FILL(v2f16);
1586         else
1587         {
1588             DE_ASSERT(false);
1589         }
1590     }
1591     else if (vectorType == VectorType::V3)
1592     {
1593         if (dataType == DataType::INT32)
1594             GEN_V3_FILL(v3i32);
1595         else if (dataType == DataType::UINT32)
1596             GEN_V3_FILL(v3u32);
1597         else if (dataType == DataType::INT64)
1598             GEN_V3_FILL(v3i64);
1599         else if (dataType == DataType::UINT64)
1600             GEN_V3_FILL(v3u64);
1601         else if (dataType == DataType::INT16)
1602             GEN_V3_FILL(v3i16);
1603         else if (dataType == DataType::UINT16)
1604             GEN_V3_FILL(v3u16);
1605         else if (dataType == DataType::INT8)
1606             GEN_V3_FILL(v3i8);
1607         else if (dataType == DataType::UINT8)
1608             GEN_V3_FILL(v3u8);
1609         else if (dataType == DataType::FLOAT32)
1610             GEN_V3_FILL(v3f32);
1611         else if (dataType == DataType::FLOAT64)
1612             GEN_V3_FILL(v3f64);
1613         else if (dataType == DataType::FLOAT16)
1614             GEN_V3_FILL(v3f16);
1615         else
1616         {
1617             DE_ASSERT(false);
1618         }
1619     }
1620     else if (vectorType == VectorType::V4)
1621     {
1622         if (dataType == DataType::INT32)
1623             GEN_V4_FILL(v4i32);
1624         else if (dataType == DataType::UINT32)
1625             GEN_V4_FILL(v4u32);
1626         else if (dataType == DataType::INT64)
1627             GEN_V4_FILL(v4i64);
1628         else if (dataType == DataType::UINT64)
1629             GEN_V4_FILL(v4u64);
1630         else if (dataType == DataType::INT16)
1631             GEN_V4_FILL(v4i16);
1632         else if (dataType == DataType::UINT16)
1633             GEN_V4_FILL(v4u16);
1634         else if (dataType == DataType::INT8)
1635             GEN_V4_FILL(v4i8);
1636         else if (dataType == DataType::UINT8)
1637             GEN_V4_FILL(v4u8);
1638         else if (dataType == DataType::FLOAT32)
1639             GEN_V4_FILL(v4f32);
1640         else if (dataType == DataType::FLOAT64)
1641             GEN_V4_FILL(v4f64);
1642         else if (dataType == DataType::FLOAT16)
1643             GEN_V4_FILL(v4f16);
1644         else
1645         {
1646             DE_ASSERT(false);
1647         }
1648     }
1649     else if (vectorType == VectorType::A5)
1650     {
1651         if (dataType == DataType::INT32)
1652             GEN_A5_FILL(a5i32);
1653         else if (dataType == DataType::UINT32)
1654             GEN_A5_FILL(a5u32);
1655         else if (dataType == DataType::INT64)
1656             GEN_A5_FILL(a5i64);
1657         else if (dataType == DataType::UINT64)
1658             GEN_A5_FILL(a5u64);
1659         else if (dataType == DataType::INT16)
1660             GEN_A5_FILL(a5i16);
1661         else if (dataType == DataType::UINT16)
1662             GEN_A5_FILL(a5u16);
1663         else if (dataType == DataType::INT8)
1664             GEN_A5_FILL(a5i8);
1665         else if (dataType == DataType::UINT8)
1666             GEN_A5_FILL(a5u8);
1667         else if (dataType == DataType::FLOAT32)
1668             GEN_A5_FILL(a5f32);
1669         else if (dataType == DataType::FLOAT64)
1670             GEN_A5_FILL(a5f64);
1671         else if (dataType == DataType::FLOAT16)
1672             GEN_A5_FILL(a5f16);
1673         else
1674         {
1675             DE_ASSERT(false);
1676         }
1677     }
1678     else
1679     {
1680         DE_ASSERT(false);
1681     }
1682 }
1683 
iterate(void)1684 tcu::TestStatus DataSpillTestInstance::iterate(void)
1685 {
1686     const auto &vki           = m_context.getInstanceInterface();
1687     const auto physicalDevice = m_context.getPhysicalDevice();
1688     const auto &vkd           = m_context.getDeviceInterface();
1689     const auto device         = m_context.getDevice();
1690     const auto queue          = m_context.getUniversalQueue();
1691     const auto familyIndex    = m_context.getUniversalQueueFamilyIndex();
1692     auto &alloc               = m_context.getDefaultAllocator();
1693     const auto shaderStages   = getShaderStages(m_params.callType);
1694 
1695     // Command buffer.
1696     const auto cmdPool      = makeCommandPool(vkd, device, familyIndex);
1697     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1698     const auto cmdBuffer    = cmdBufferPtr.get();
1699 
1700     beginCommandBuffer(vkd, cmdBuffer);
1701 
1702     // Callee, input and output buffers.
1703     const auto calleeBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1704     const auto outputBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1705     const auto inputBufferSize  = getElementSize(m_params.dataType, m_params.vectorType);
1706 
1707     const auto calleeBufferInfo = makeBufferCreateInfo(calleeBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1708     const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1709     const auto inputBufferInfo  = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1710 
1711     BufferWithMemory calleeBuffer(vkd, device, alloc, calleeBufferInfo, MemoryRequirement::HostVisible);
1712     BufferWithMemory outputBuffer(vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
1713     BufferWithMemory inputBuffer(vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
1714 
1715     // Fill buffers with values.
1716     auto &calleeBufferAlloc = calleeBuffer.getAllocation();
1717     auto *calleeBufferPtr   = calleeBufferAlloc.getHostPtr();
1718     auto &outputBufferAlloc = outputBuffer.getAllocation();
1719     auto *outputBufferPtr   = outputBufferAlloc.getHostPtr();
1720     auto &inputBufferAlloc  = inputBuffer.getAllocation();
1721     auto *inputBufferPtr    = inputBufferAlloc.getHostPtr();
1722 
1723     deMemset(calleeBufferPtr, 0, static_cast<size_t>(calleeBufferSize));
1724     deMemset(outputBufferPtr, 0, static_cast<size_t>(outputBufferSize));
1725 
1726     if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1727     {
1728         // The input buffer for these cases will be filled with zeros (sampling coordinates), and the input textures will contain the interesting input value.
1729         deMemset(inputBufferPtr, 0, static_cast<size_t>(inputBufferSize));
1730     }
1731     else
1732     {
1733         // We want to fill the input buffer with values that will be consistently used in the shader to obtain a result of zero.
1734         fillInputBuffer(m_params.dataType, m_params.vectorType, inputBufferPtr);
1735     }
1736 
1737     flushAlloc(vkd, device, calleeBufferAlloc);
1738     flushAlloc(vkd, device, outputBufferAlloc);
1739     flushAlloc(vkd, device, inputBufferAlloc);
1740 
1741     // Acceleration structures.
1742     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
1743     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
1744 
1745     bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
1746     bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.callType),
1747                                                              VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
1748     bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1749 
1750     topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
1751     topLevelAccelerationStructure->setInstanceCount(1);
1752     topLevelAccelerationStructure->addInstance(
1753         de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
1754     topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1755 
1756     // Get some ray tracing properties.
1757     uint32_t shaderGroupHandleSize    = 0u;
1758     uint32_t shaderGroupBaseAlignment = 1u;
1759     {
1760         const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
1761         shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
1762         shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
1763     }
1764 
1765     // Textures and samplers if needed.
1766     de::MovePtr<BufferWithMemory> textureData;
1767     std::vector<de::MovePtr<ImageWithMemory>> textures;
1768     std::vector<Move<VkImageView>> textureViews;
1769     std::vector<Move<VkSampler>> samplers;
1770 
1771     if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1772     {
1773         // Create texture data with the expected contents.
1774         {
1775             const auto textureDataSize       = static_cast<VkDeviceSize>(sizeof(uint32_t));
1776             const auto textureDataCreateInfo = makeBufferCreateInfo(textureDataSize, VK_BUFFER_USAGE_TRANSFER_SRC_BIT);
1777 
1778             textureData = de::MovePtr<BufferWithMemory>(
1779                 new BufferWithMemory(vkd, device, alloc, textureDataCreateInfo, MemoryRequirement::HostVisible));
1780             auto &textureDataAlloc = textureData->getAllocation();
1781             auto *textureDataPtr   = textureDataAlloc.getHostPtr();
1782 
1783             fillInputBuffer(DataType::UINT32, VectorType::SCALAR, textureDataPtr);
1784             flushAlloc(vkd, device, textureDataAlloc);
1785         }
1786 
1787         // Images will be created like this with different usages.
1788         VkImageCreateInfo imageCreateInfo = {
1789             VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
1790             nullptr,                             // const void* pNext;
1791             0u,                                  // VkImageCreateFlags flags;
1792             VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
1793             kImageFormat,                        // VkFormat format;
1794             kImageExtent,                        // VkExtent3D extent;
1795             1u,                                  // uint32_t mipLevels;
1796             1u,                                  // uint32_t arrayLayers;
1797             VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
1798             VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
1799             kSampledImageUsage,                  // VkImageUsageFlags usage;
1800             VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
1801             0u,                                  // uint32_t queueFamilyIndexCount;
1802             nullptr,                             // const uint32_t* pQueueFamilyIndices;
1803             VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
1804         };
1805 
1806         const auto imageSubresourceRange  = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
1807         const auto imageSubresourceLayers = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
1808 
1809         if (samplersNeeded(m_params.dataType))
1810         {
1811             // All samplers will be created like this.
1812             const VkSamplerCreateInfo samplerCreateInfo = {
1813                 VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // VkStructureType sType;
1814                 nullptr,                               // const void* pNext;
1815                 0u,                                    // VkSamplerCreateFlags flags;
1816                 VK_FILTER_NEAREST,                     // VkFilter magFilter;
1817                 VK_FILTER_NEAREST,                     // VkFilter minFilter;
1818                 VK_SAMPLER_MIPMAP_MODE_NEAREST,        // VkSamplerMipmapMode mipmapMode;
1819                 VK_SAMPLER_ADDRESS_MODE_REPEAT,        // VkSamplerAddressMode addressModeU;
1820                 VK_SAMPLER_ADDRESS_MODE_REPEAT,        // VkSamplerAddressMode addressModeV;
1821                 VK_SAMPLER_ADDRESS_MODE_REPEAT,        // VkSamplerAddressMode addressModeW;
1822                 0.0,                                   // float mipLodBias;
1823                 VK_FALSE,                              // VkBool32 anisotropyEnable;
1824                 1.0f,                                  // float maxAnisotropy;
1825                 VK_FALSE,                              // VkBool32 compareEnable;
1826                 VK_COMPARE_OP_ALWAYS,                  // VkCompareOp compareOp;
1827                 0.0f,                                  // float minLod;
1828                 1.0f,                                  // float maxLod;
1829                 VK_BORDER_COLOR_INT_OPAQUE_BLACK,      // VkBorderColor borderColor;
1830                 VK_FALSE,                              // VkBool32 unnormalizedCoordinates;
1831             };
1832 
1833             // Create textures and samplers.
1834             for (size_t i = 0; i < kNumImages; ++i)
1835             {
1836                 textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1837                 textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D,
1838                                                         kImageFormat, imageSubresourceRange));
1839             }
1840 
1841             for (size_t i = 0; i < kNumSamplers; ++i)
1842                 samplers.emplace_back(createSampler(vkd, device, &samplerCreateInfo));
1843 
1844             // Make sure texture data is available in the transfer stage.
1845             const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1846             vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u,
1847                                    &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1848 
1849             const auto bufferImageCopy = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1850 
1851             // Fill textures with data and prepare them for the ray tracing pipeline stages.
1852             for (size_t i = 0; i < kNumImages; ++i)
1853             {
1854                 const auto texturePreCopyBarrier = makeImageMemoryBarrier(
1855                     0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1856                     textures[i]->get(), imageSubresourceRange);
1857                 const auto texturePostCopyBarrier = makeImageMemoryBarrier(
1858                     VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1859                     VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, textures[i]->get(), imageSubresourceRange);
1860 
1861                 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u,
1862                                        0u, nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1863                 vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures[i]->get(),
1864                                          VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1865                 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
1866                                        VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u,
1867                                        &texturePostCopyBarrier);
1868             }
1869         }
1870         else if (storageImageNeeded(m_params.dataType))
1871         {
1872             // Image will be used for storage.
1873             imageCreateInfo.usage = kStorageImageUsage;
1874 
1875             textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1876             textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D,
1877                                                     kImageFormat, imageSubresourceRange));
1878 
1879             // Make sure texture data is available in the transfer stage.
1880             const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1881             vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u,
1882                                    &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1883 
1884             const auto bufferImageCopy       = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1885             const auto texturePreCopyBarrier = makeImageMemoryBarrier(
1886                 0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1887                 textures.back()->get(), imageSubresourceRange);
1888             const auto texturePostCopyBarrier = makeImageMemoryBarrier(
1889                 VK_ACCESS_TRANSFER_WRITE_BIT, (VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT),
1890                 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, textures.back()->get(),
1891                 imageSubresourceRange);
1892 
1893             // Fill texture with data and prepare them for the ray tracing pipeline stages.
1894             vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 0u,
1895                                    nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1896             vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures.back()->get(),
1897                                      VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1898             vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
1899                                    VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u,
1900                                    &texturePostCopyBarrier);
1901         }
1902         else
1903         {
1904             DE_ASSERT(false);
1905         }
1906     }
1907 
1908     // Descriptor set layout.
1909     DescriptorSetLayoutBuilder dslBuilder;
1910     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
1911     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Callee buffer.
1912     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Output buffer.
1913     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Input buffer.
1914     if (samplersNeeded(m_params.dataType))
1915     {
1916         dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u, shaderStages, nullptr);
1917         dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLER, 2u, shaderStages, nullptr);
1918         dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u, shaderStages, nullptr);
1919     }
1920     else if (storageImageNeeded(m_params.dataType))
1921     {
1922         dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u, shaderStages, nullptr);
1923     }
1924     const auto descriptorSetLayout = dslBuilder.build(vkd, device);
1925 
1926     // Pipeline layout.
1927     const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
1928 
1929     // Descriptor pool and set.
1930     DescriptorPoolBuilder poolBuilder;
1931     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
1932     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u);
1933     if (samplersNeeded(m_params.dataType))
1934     {
1935         poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u);
1936         poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLER, 2u);
1937         poolBuilder.addType(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u);
1938     }
1939     else if (storageImageNeeded(m_params.dataType))
1940     {
1941         poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u);
1942     }
1943     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
1944     const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
1945 
1946     // Update descriptor set.
1947     {
1948         const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo = {
1949             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
1950             nullptr,
1951             1u,
1952             topLevelAccelerationStructure.get()->getPtr(),
1953         };
1954 
1955         DescriptorSetUpdateBuilder updateBuilder;
1956 
1957         const auto ds = descriptorSet.get();
1958 
1959         const auto calleeBufferDescriptorInfo = makeDescriptorBufferInfo(calleeBuffer.get(), 0ull, VK_WHOLE_SIZE);
1960         const auto outputBufferDescriptorInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1961         const auto inputBufferDescriptorInfo  = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1962 
1963         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u),
1964                                   VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
1965         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u),
1966                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &calleeBufferDescriptorInfo);
1967         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(2u),
1968                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
1969         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(3u),
1970                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
1971 
1972         if (samplersNeeded(m_params.dataType))
1973         {
1974             // Update textures, samplers and combined image samplers.
1975             std::vector<VkDescriptorImageInfo> textureDescInfos;
1976             std::vector<VkDescriptorImageInfo> textureSamplerInfos;
1977             std::vector<VkDescriptorImageInfo> combinedSamplerInfos;
1978 
1979             for (size_t i = 0; i < kNumAloneImages; ++i)
1980                 textureDescInfos.push_back(
1981                     makeDescriptorImageInfo(DE_NULL, textureViews[i].get(), VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1982             for (size_t i = 0; i < kNumAloneSamplers; ++i)
1983                 textureSamplerInfos.push_back(
1984                     makeDescriptorImageInfo(samplers[i].get(), DE_NULL, VK_IMAGE_LAYOUT_UNDEFINED));
1985 
1986             for (size_t i = 0; i < kNumCombined; ++i)
1987                 combinedSamplerInfos.push_back(makeDescriptorImageInfo(samplers[i + kNumAloneSamplers].get(),
1988                                                                        textureViews[i + kNumAloneImages].get(),
1989                                                                        VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1990 
1991             updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(4u),
1992                                      VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, kNumAloneImages, textureDescInfos.data());
1993             updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(5u), VK_DESCRIPTOR_TYPE_SAMPLER,
1994                                      kNumAloneSamplers, textureSamplerInfos.data());
1995             updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(6u),
1996                                      VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, kNumCombined,
1997                                      combinedSamplerInfos.data());
1998         }
1999         else if (storageImageNeeded(m_params.dataType))
2000         {
2001             const auto storageImageDescriptorInfo =
2002                 makeDescriptorImageInfo(DE_NULL, textureViews.back().get(), VK_IMAGE_LAYOUT_GENERAL);
2003             updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(4u),
2004                                       VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &storageImageDescriptorInfo);
2005         }
2006 
2007         updateBuilder.update(vkd, device);
2008     }
2009 
2010     // Create raytracing pipeline and shader binding tables.
2011     Move<VkPipeline> pipeline;
2012 
2013     de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
2014     de::MovePtr<BufferWithMemory> missShaderBindingTable;
2015     de::MovePtr<BufferWithMemory> hitShaderBindingTable;
2016     de::MovePtr<BufferWithMemory> callableShaderBindingTable;
2017 
2018     VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion   = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2019     VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion     = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2020     VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion      = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2021     VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2022 
2023     {
2024         const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
2025         const auto callType           = m_params.callType;
2026 
2027         // Every case uses a ray generation shader.
2028         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
2029                                       createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0),
2030                                       0);
2031 
2032         if (callType == CallType::TRACE_RAY)
2033         {
2034             rayTracingPipeline->addShader(
2035                 VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2036                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
2037         }
2038         else if (callType == CallType::EXECUTE_CALLABLE)
2039         {
2040             rayTracingPipeline->addShader(
2041                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2042                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
2043         }
2044         else if (callType == CallType::REPORT_INTERSECTION)
2045         {
2046             rayTracingPipeline->addShader(
2047                 VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
2048                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
2049             rayTracingPipeline->addShader(
2050                 VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
2051                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("ahit"), 0), 1);
2052         }
2053         else
2054         {
2055             DE_ASSERT(false);
2056         }
2057 
2058         pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
2059 
2060         raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2061             vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
2062         raygenShaderBindingTableRegion =
2063             makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
2064                                               shaderGroupHandleSize, shaderGroupHandleSize);
2065 
2066         if (callType == CallType::EXECUTE_CALLABLE)
2067         {
2068             callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2069                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2070             callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2071                 getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
2072                 shaderGroupHandleSize);
2073         }
2074         else if (callType == CallType::TRACE_RAY || callType == CallType::REPORT_INTERSECTION)
2075         {
2076             hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2077                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2078             hitShaderBindingTableRegion =
2079                 makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
2080                                                   shaderGroupHandleSize, shaderGroupHandleSize);
2081         }
2082         else
2083         {
2084             DE_ASSERT(false);
2085         }
2086     }
2087 
2088     // Use ray tracing pipeline.
2089     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
2090     vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
2091                               &descriptorSet.get(), 0u, nullptr);
2092     vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
2093                         &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);
2094 
2095     // Synchronize output and callee buffers.
2096     const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
2097     vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
2098                            &memBarrier, 0u, nullptr, 0u, nullptr);
2099 
2100     endCommandBuffer(vkd, cmdBuffer);
2101     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
2102 
2103     // Verify output and callee buffers.
2104     invalidateAlloc(vkd, device, outputBufferAlloc);
2105     invalidateAlloc(vkd, device, calleeBufferAlloc);
2106 
2107     std::map<std::string, void *> bufferPtrs;
2108     bufferPtrs["output"] = outputBufferPtr;
2109     bufferPtrs["callee"] = calleeBufferPtr;
2110 
2111     for (const auto &ptr : bufferPtrs)
2112     {
2113         const auto &bufferName = ptr.first;
2114         const auto &bufferPtr  = ptr.second;
2115 
2116         uint32_t outputVal;
2117         deMemcpy(&outputVal, bufferPtr, sizeof(outputVal));
2118 
2119         if (outputVal != 1u)
2120             return tcu::TestStatus::fail("Unexpected value found in " + bufferName +
2121                                          " buffer: " + de::toString(outputVal));
2122     }
2123 
2124     return tcu::TestStatus::pass("Pass");
2125 }
2126 
// Selects which ray tracing pipeline interface mechanism a DataSpillPipelineInterfaceTestCase exercises.
// The values are spelled out explicitly; they match the implicit 0..6 numbering.
enum class InterfaceType
{
    RAY_PAYLOAD               = 0, // rayPayloadEXT in rgen / rayPayloadInEXT in the closest hit shader.
    CALLABLE_DATA             = 1, // callableDataEXT in rgen / callableDataInEXT in a callable shader.
    HIT_ATTRIBUTES            = 2, // hitAttributeEXT written by the intersection shader, read by closest hit.
    SHADER_RECORD_BUFFER_RGEN = 3, // shaderRecordEXT buffer read by the ray generation shader.
    SHADER_RECORD_BUFFER_CALL = 4, // shaderRecordEXT buffer read by an intermediate callable shader.
    SHADER_RECORD_BUFFER_MISS = 5, // Like SHADER_RECORD_BUFFER_CALL, with a miss shader in place of the callable.
    SHADER_RECORD_BUFFER_HIT  = 6, // Like SHADER_RECORD_BUFFER_CALL, with a closest hit shader in place of the callable.
};
2137 
2138 // Separate class to ease testing pipeline interface variables.
2139 class DataSpillPipelineInterfaceTestCase : public vkt::TestCase
2140 {
2141 public:
2142     struct TestParams
2143     {
2144         InterfaceType interfaceType;
2145     };
2146 
2147     DataSpillPipelineInterfaceTestCase(tcu::TestContext &testCtx, const std::string &name,
2148                                        const TestParams &testParams);
~DataSpillPipelineInterfaceTestCase(void)2149     virtual ~DataSpillPipelineInterfaceTestCase(void)
2150     {
2151     }
2152 
2153     virtual void initPrograms(vk::SourceCollections &programCollection) const;
2154     virtual TestInstance *createInstance(Context &context) const;
2155     virtual void checkSupport(Context &context) const;
2156 
2157 private:
2158     TestParams m_params;
2159 };
2160 
2161 class DataSpillPipelineInterfaceTestInstance : public vkt::TestInstance
2162 {
2163 public:
2164     using TestParams = DataSpillPipelineInterfaceTestCase::TestParams;
2165 
2166     DataSpillPipelineInterfaceTestInstance(Context &context, const TestParams &testParams);
~DataSpillPipelineInterfaceTestInstance(void)2167     ~DataSpillPipelineInterfaceTestInstance(void)
2168     {
2169     }
2170 
2171     tcu::TestStatus iterate(void);
2172 
2173 private:
2174     TestParams m_params;
2175 };
2176 
DataSpillPipelineInterfaceTestCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & testParams)2177 DataSpillPipelineInterfaceTestCase::DataSpillPipelineInterfaceTestCase(tcu::TestContext &testCtx,
2178                                                                        const std::string &name,
2179                                                                        const TestParams &testParams)
2180     : vkt::TestCase(testCtx, name)
2181     , m_params(testParams)
2182 {
2183 }
2184 
createInstance(Context & context) const2185 TestInstance *DataSpillPipelineInterfaceTestCase::createInstance(Context &context) const
2186 {
2187     return new DataSpillPipelineInterfaceTestInstance(context, m_params);
2188 }
2189 
DataSpillPipelineInterfaceTestInstance(Context & context,const TestParams & testParams)2190 DataSpillPipelineInterfaceTestInstance::DataSpillPipelineInterfaceTestInstance(Context &context,
2191                                                                                const TestParams &testParams)
2192     : vkt::TestInstance(context)
2193     , m_params(testParams)
2194 {
2195 }
2196 
// Support check for the pipeline interface cases: delegates entirely to commonCheckSupport().
// No additional checks that depend on m_params.interfaceType are made here.
void DataSpillPipelineInterfaceTestCase::checkSupport(Context &context) const
{
    commonCheckSupport(context);
}
2201 
initPrograms(vk::SourceCollections & programCollection) const2202 void DataSpillPipelineInterfaceTestCase::initPrograms(vk::SourceCollections &programCollection) const
2203 {
2204     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
2205 
2206     const std::string glslHeader = "#version 460 core\n"
2207                                    "#extension GL_EXT_ray_tracing : require\n";
2208 
2209     const std::string glslBindings = "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2210                                      "layout(set = 0, binding = 1) buffer StorageBlock { uint val[" +
2211                                      std::to_string(kNumStorageValues) + "]; } storageBuffer;\n";
2212 
2213     if (m_params.interfaceType == InterfaceType::RAY_PAYLOAD)
2214     {
2215         // The closest hit shader will store 100 in the second array position.
2216         // The ray gen shader will store 103 in the first array position using the hitValue after the traceRayExt() call.
2217 
2218         std::ostringstream rgen;
2219         rgen << glslHeader << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2220              << glslBindings << "void main()\n"
2221              << "{\n"
2222              << "  hitValue = vec3(10.0, 30.0, 60.0);\n"
2223              << "  traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2224                 "0);\n"
2225              << "  storageBuffer.val[0] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2226              << "}\n";
2227         programCollection.glslSources.add("rgen")
2228             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2229 
2230         std::stringstream chit;
2231         chit << glslHeader << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2232              << "hitAttributeEXT vec3 attribs;\n"
2233              << glslBindings << "void main()\n"
2234              << "{\n"
2235              << "  storageBuffer.val[1] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2236              << "  hitValue = vec3(hitValue.x + 1.0, hitValue.y + 1.0, hitValue.z + 1.0);\n"
2237              << "}\n";
2238         programCollection.glslSources.add("chit")
2239             << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2240     }
2241     else if (m_params.interfaceType == InterfaceType::CALLABLE_DATA)
2242     {
2243         // The callable shader shader will store 100 in the second array position.
2244         // The ray gen shader will store 200 in the first array position using the callable data after the executeCallableEXT() call.
2245 
2246         std::ostringstream rgen;
2247         rgen << glslHeader << "layout(location = 0) callableDataEXT float callableData;\n"
2248              << glslBindings << "void main()\n"
2249              << "{\n"
2250              << "  callableData = 100.0;\n"
2251              << "  executeCallableEXT(0, 0);\n"
2252              << "  storageBuffer.val[0] = uint(callableData);\n"
2253              << "}\n";
2254         programCollection.glslSources.add("rgen")
2255             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2256 
2257         std::ostringstream call;
2258         call << glslHeader << "layout(location = 0) callableDataInEXT float callableData;\n"
2259              << glslBindings << "void main()\n"
2260              << "{\n"
2261              << "    storageBuffer.val[1] = uint(callableData);\n"
2262              << "    callableData = callableData * 2.0;\n"
2263              << "}\n";
2264 
2265         programCollection.glslSources.add("call")
2266             << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2267     }
2268     else if (m_params.interfaceType == InterfaceType::HIT_ATTRIBUTES)
2269     {
2270         // The ray gen shader will store value 300 in the first storage buffer position.
2271         // The intersection shader will store value 315 in the second storage buffer position.
2272         // The closes hit shader will store value 330 in the third storage buffer position using the hit attributes.
2273 
2274         std::ostringstream rgen;
2275         rgen << glslHeader << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2276              << glslBindings << "void main()\n"
2277              << "{\n"
2278              << "  traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2279                 "0);\n"
2280              << "  storageBuffer.val[0] = 300u;\n"
2281              << "}\n";
2282         programCollection.glslSources.add("rgen")
2283             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2284 
2285         std::stringstream rint;
2286         rint << glslHeader << "hitAttributeEXT vec3 attribs;\n"
2287              << glslBindings << "void main()\n"
2288              << "{\n"
2289              << "  attribs = vec3(140.0, 160.0, 30.0);\n"
2290              << "  storageBuffer.val[1] = 315u;\n"
2291              << "  reportIntersectionEXT(1.0f, 0);\n"
2292              << "}\n";
2293 
2294         programCollection.glslSources.add("rint")
2295             << glu::IntersectionSource(updateRayTracingGLSL(rint.str())) << buildOptions;
2296 
2297         std::stringstream chit;
2298         chit << glslHeader << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2299              << "hitAttributeEXT vec3 attribs;\n"
2300              << glslBindings << "void main()\n"
2301              << "{\n"
2302              << "  storageBuffer.val[2] = uint(attribs.x + attribs.y + attribs.z);\n"
2303              << "}\n";
2304         programCollection.glslSources.add("chit")
2305             << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2306     }
2307     else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2308     {
2309         // The ray gen shader will have a uvec4 in the shader record buffer with contents 400, 401, 402, 403.
2310         // The shader will call a callable shader indicating a position in that vec4 (0, 1, 2, 3). For example, let's use position 1.
2311         // The callable shader will return the indicated position+1 modulo 4, so it will return 2 in our case.
2312         // *After* returning from the callable shader, the raygen shader will use that reply to access position 2 and write a 402 in the first output buffer position.
2313         // The callable shader will store 450 in the second output buffer position.
2314 
2315         std::ostringstream rgen;
2316         rgen << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2317              << "  uvec4 info;\n"
2318              << "};\n"
2319              << "layout(location = 0) callableDataEXT uint callableData;\n"
2320              << glslBindings << "void main()\n"
2321              << "{\n"
2322              << "  callableData = 1u;"
2323              << "  executeCallableEXT(0, 0);\n"
2324              << "  if      (callableData == 0u) storageBuffer.val[0] = info.x;\n"
2325              << "  else if (callableData == 1u) storageBuffer.val[0] = info.y;\n"
2326              << "  else if (callableData == 2u) storageBuffer.val[0] = info.z;\n"
2327              << "  else if (callableData == 3u) storageBuffer.val[0] = info.w;\n"
2328              << "}\n";
2329         programCollection.glslSources.add("rgen")
2330             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2331 
2332         std::ostringstream call;
2333         call << glslHeader << "layout(location = 0) callableDataInEXT uint callableData;\n"
2334              << glslBindings << "void main()\n"
2335              << "{\n"
2336              << "    storageBuffer.val[1] = 450u;\n"
2337              << "    callableData = (callableData + 1u) % 4u;\n"
2338              << "}\n";
2339 
2340         programCollection.glslSources.add("call")
2341             << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2342     }
2343     else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2344     {
2345         // Similar to the previous case, with a twist:
2346         //   * rgen passes the vector position.
2347         //   * call increases that by one.
2348         //   * subcall increases again and does the modulo operation, also writing 450 in the third output buffer value.
2349         //   * call is the one accessing the vector at the returned position, writing 403 in this case to the second output buffer value.
2350         //   * call passes this value back doubled to rgen, which writes it to the first output buffer value (806).
2351 
2352         std::ostringstream rgen;
2353         rgen << glslHeader << "layout(location = 0) callableDataEXT uint callableData;\n"
2354              << glslBindings << "void main()\n"
2355              << "{\n"
2356              << "  callableData = 1u;\n"
2357              << "  executeCallableEXT(0, 0);\n"
2358              << "  storageBuffer.val[0] = callableData;\n"
2359              << "}\n";
2360         programCollection.glslSources.add("rgen")
2361             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2362 
2363         std::ostringstream call;
2364         call << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2365              << "  uvec4 info;\n"
2366              << "};\n"
2367              << "layout(location = 0) callableDataInEXT uint callableDataIn;\n"
2368              << "layout(location = 1) callableDataEXT uint callableDataOut;\n"
2369              << glslBindings << "void main()\n"
2370              << "{\n"
2371              << "  callableDataOut = callableDataIn + 1u;\n"
2372              << "  executeCallableEXT(1, 1);\n"
2373              << "  uint outputBufferValue = 777u;\n"
2374              << "  if      (callableDataOut == 0u) outputBufferValue = info.x;\n"
2375              << "  else if (callableDataOut == 1u) outputBufferValue = info.y;\n"
2376              << "  else if (callableDataOut == 2u) outputBufferValue = info.z;\n"
2377              << "  else if (callableDataOut == 3u) outputBufferValue = info.w;\n"
2378              << "  storageBuffer.val[1] = outputBufferValue;\n"
2379              << "  callableDataIn = outputBufferValue * 2u;\n"
2380              << "}\n";
2381 
2382         programCollection.glslSources.add("call")
2383             << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2384 
2385         std::ostringstream subcall;
2386         subcall << glslHeader << "layout(location = 1) callableDataInEXT uint callableData;\n"
2387                 << glslBindings << "void main()\n"
2388                 << "{\n"
2389                 << "  callableData = (callableData + 1u) % 4u;\n"
2390                 << "  storageBuffer.val[2] = 450u;\n"
2391                 << "}\n";
2392 
2393         programCollection.glslSources.add("subcall")
2394             << glu::CallableSource(updateRayTracingGLSL(subcall.str())) << buildOptions;
2395     }
2396     else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS ||
2397              m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2398     {
2399         // Similar to the previous one, but the intermediate call shader has been replaced with a miss or closest hit shader.
2400         // The rgen shader will communicate with the miss/chit shader using the ray payload instead of the callable data.
2401         // Also, the initial position will be 2, so it will wrap around in this case. The numbers will also change.
2402 
2403         std::ostringstream rgen;
2404         rgen << glslHeader << "layout(location = 0) rayPayloadEXT uint rayPayload;\n"
2405              << glslBindings << "void main()\n"
2406              << "{\n"
2407              << "  rayPayload = 2u;\n"
2408              << "  traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2409                 "0);\n"
2410              << "  storageBuffer.val[0] = rayPayload;\n"
2411              << "}\n";
2412         programCollection.glslSources.add("rgen")
2413             << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2414 
2415         std::ostringstream chitOrMiss;
2416         chitOrMiss << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2417                    << "  uvec4 info;\n"
2418                    << "};\n"
2419                    << "layout(location = 0) rayPayloadInEXT uint rayPayload;\n"
2420                    << "layout(location = 0) callableDataEXT uint callableData;\n"
2421                    << glslBindings << "void main()\n"
2422                    << "{\n"
2423                    << "  callableData = rayPayload + 1u;\n"
2424                    << "  executeCallableEXT(0, 0);\n"
2425                    << "  uint outputBufferValue = 777u;\n"
2426                    << "  if      (callableData == 0u) outputBufferValue = info.x;\n"
2427                    << "  else if (callableData == 1u) outputBufferValue = info.y;\n"
2428                    << "  else if (callableData == 2u) outputBufferValue = info.z;\n"
2429                    << "  else if (callableData == 3u) outputBufferValue = info.w;\n"
2430                    << "  storageBuffer.val[1] = outputBufferValue;\n"
2431                    << "  rayPayload = outputBufferValue * 3u;\n"
2432                    << "}\n";
2433 
2434         if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
2435             programCollection.glslSources.add("miss")
2436                 << glu::MissSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2437         else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2438             programCollection.glslSources.add("chit")
2439                 << glu::ClosestHitSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2440         else
2441             DE_ASSERT(false);
2442 
2443         std::ostringstream call;
2444         call << glslHeader << "layout(location = 0) callableDataInEXT uint callableData;\n"
2445              << glslBindings << "void main()\n"
2446              << "{\n"
2447              << "    storageBuffer.val[2] = 490u;\n"
2448              << "    callableData = (callableData + 1u) % 4u;\n"
2449              << "}\n";
2450 
2451         programCollection.glslSources.add("call")
2452             << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2453     }
2454     else
2455     {
2456         DE_ASSERT(false);
2457     }
2458 }
2459 
getShaderStages(InterfaceType type_)2460 VkShaderStageFlags getShaderStages(InterfaceType type_)
2461 {
2462     VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
2463 
2464     switch (type_)
2465     {
2466     case InterfaceType::HIT_ATTRIBUTES:
2467         flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2468         // fallthrough.
2469     case InterfaceType::RAY_PAYLOAD:
2470         flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2471         break;
2472     case InterfaceType::CALLABLE_DATA:
2473     case InterfaceType::SHADER_RECORD_BUFFER_RGEN:
2474     case InterfaceType::SHADER_RECORD_BUFFER_CALL:
2475         flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2476         break;
2477     case InterfaceType::SHADER_RECORD_BUFFER_MISS:
2478         flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2479         flags |= VK_SHADER_STAGE_MISS_BIT_KHR;
2480         break;
2481     case InterfaceType::SHADER_RECORD_BUFFER_HIT:
2482         flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2483         flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2484         break;
2485     default:
2486         DE_ASSERT(false);
2487         break;
2488     }
2489 
2490     return flags;
2491 }
2492 
2493 // Proper stage for generating default geometry.
getShaderStageForGeometry(InterfaceType type_)2494 VkShaderStageFlagBits getShaderStageForGeometry(InterfaceType type_)
2495 {
2496     VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
2497 
2498     switch (type_)
2499     {
2500     case InterfaceType::HIT_ATTRIBUTES:
2501         bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2502         break;
2503     case InterfaceType::RAY_PAYLOAD:
2504         bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2505         break;
2506     case InterfaceType::CALLABLE_DATA:
2507         bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2508         break;
2509     case InterfaceType::SHADER_RECORD_BUFFER_RGEN:
2510         bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2511         break;
2512     case InterfaceType::SHADER_RECORD_BUFFER_CALL:
2513         bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2514         break;
2515     case InterfaceType::SHADER_RECORD_BUFFER_MISS:
2516         bits = VK_SHADER_STAGE_MISS_BIT_KHR;
2517         break;
2518     case InterfaceType::SHADER_RECORD_BUFFER_HIT:
2519         bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2520         break;
2521     default:
2522         DE_ASSERT(false);
2523         break;
2524     }
2525 
2526     DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
2527     return bits;
2528 }
2529 
createSBTWithShaderRecord(const DeviceInterface & vkd,VkDevice device,vk::Allocator & alloc,VkPipeline pipeline,RayTracingPipeline * rayTracingPipeline,uint32_t shaderGroupHandleSize,uint32_t shaderGroupBaseAlignment,uint32_t firstGroup,uint32_t groupCount,de::MovePtr<BufferWithMemory> & shaderBindingTable,VkStridedDeviceAddressRegionKHR & shaderBindingTableRegion)2530 void createSBTWithShaderRecord(const DeviceInterface &vkd, VkDevice device, vk::Allocator &alloc, VkPipeline pipeline,
2531                                RayTracingPipeline *rayTracingPipeline, uint32_t shaderGroupHandleSize,
2532                                uint32_t shaderGroupBaseAlignment, uint32_t firstGroup, uint32_t groupCount,
2533                                de::MovePtr<BufferWithMemory> &shaderBindingTable,
2534                                VkStridedDeviceAddressRegionKHR &shaderBindingTableRegion)
2535 {
2536     const auto alignedSize = de::roundUp(shaderGroupHandleSize + kShaderRecordSize, shaderGroupHandleSize);
2537     shaderBindingTable     = rayTracingPipeline->createShaderBindingTable(
2538         vkd, device, pipeline, alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, firstGroup, groupCount, 0u, 0u,
2539         MemoryRequirement::Any, 0u, 0u, kShaderRecordSize);
2540     shaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2541         getBufferDeviceAddress(vkd, device, shaderBindingTable->get(), 0), alignedSize, groupCount * alignedSize);
2542 
2543     // Fill shader record buffer data.
2544     // Note we will only fill the first shader record after the handle.
2545     const tcu::UVec4 shaderRecordData(400u, 401u, 402u, 403u);
2546     auto &sbtAlloc = shaderBindingTable->getAllocation();
2547     auto *dataPtr  = reinterpret_cast<uint8_t *>(sbtAlloc.getHostPtr()) + shaderGroupHandleSize;
2548 
2549     DE_STATIC_ASSERT(sizeof(shaderRecordData) == static_cast<size_t>(kShaderRecordSize));
2550     deMemcpy(dataPtr, &shaderRecordData, sizeof(shaderRecordData));
2551 }
2552 
iterate(void)2553 tcu::TestStatus DataSpillPipelineInterfaceTestInstance::iterate(void)
2554 {
2555     const auto &vki           = m_context.getInstanceInterface();
2556     const auto physicalDevice = m_context.getPhysicalDevice();
2557     const auto &vkd           = m_context.getDeviceInterface();
2558     const auto device         = m_context.getDevice();
2559     const auto queue          = m_context.getUniversalQueue();
2560     const auto familyIndex    = m_context.getUniversalQueueFamilyIndex();
2561     auto &alloc               = m_context.getDefaultAllocator();
2562     const auto shaderStages   = getShaderStages(m_params.interfaceType);
2563 
2564     // Command buffer.
2565     const auto cmdPool      = makeCommandPool(vkd, device, familyIndex);
2566     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
2567     const auto cmdBuffer    = cmdBufferPtr.get();
2568 
2569     beginCommandBuffer(vkd, cmdBuffer);
2570 
2571     // Storage buffer.
2572     std::array<uint32_t, kNumStorageValues> storageBufferData;
2573     const auto storageBufferSize = de::dataSize(storageBufferData);
2574     const auto storagebufferInfo = makeBufferCreateInfo(storageBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
2575     BufferWithMemory storageBuffer(vkd, device, alloc, storagebufferInfo, MemoryRequirement::HostVisible);
2576 
2577     // Zero-out buffer.
2578     auto &storageBufferAlloc = storageBuffer.getAllocation();
2579     auto *storageBufferPtr   = storageBufferAlloc.getHostPtr();
2580     deMemset(storageBufferPtr, 0, storageBufferSize);
2581     flushAlloc(vkd, device, storageBufferAlloc);
2582 
2583     // Acceleration structures.
2584     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
2585     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
2586 
2587     bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
2588     bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.interfaceType),
2589                                                              VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
2590     bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2591 
2592     topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
2593     topLevelAccelerationStructure->setInstanceCount(1);
2594     topLevelAccelerationStructure->addInstance(
2595         de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
2596     topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2597 
2598     // Get some ray tracing properties.
2599     uint32_t shaderGroupHandleSize    = 0u;
2600     uint32_t shaderGroupBaseAlignment = 1u;
2601     {
2602         const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
2603         shaderGroupHandleSize              = rayTracingPropertiesKHR->getShaderGroupHandleSize();
2604         shaderGroupBaseAlignment           = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
2605     }
2606 
2607     // Descriptor set layout.
2608     DescriptorSetLayoutBuilder dslBuilder;
2609     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
2610     dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Callee buffer.
2611     const auto descriptorSetLayout = dslBuilder.build(vkd, device);
2612 
2613     // Pipeline layout.
2614     const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
2615 
2616     // Descriptor pool and set.
2617     DescriptorPoolBuilder poolBuilder;
2618     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
2619     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
2620     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
2621     const auto descriptorSet  = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
2622 
2623     // Update descriptor set.
2624     {
2625         const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo = {
2626             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
2627             nullptr,
2628             1u,
2629             topLevelAccelerationStructure.get()->getPtr(),
2630         };
2631 
2632         const auto ds                          = descriptorSet.get();
2633         const auto storageBufferDescriptorInfo = makeDescriptorBufferInfo(storageBuffer.get(), 0ull, VK_WHOLE_SIZE);
2634 
2635         DescriptorSetUpdateBuilder updateBuilder;
2636         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u),
2637                                   VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
2638         updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u),
2639                                   VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferDescriptorInfo);
2640         updateBuilder.update(vkd, device);
2641     }
2642 
2643     // Create raytracing pipeline and shader binding tables.
2644     const auto interfaceType = m_params.interfaceType;
2645     Move<VkPipeline> pipeline;
2646 
2647     de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
2648     de::MovePtr<BufferWithMemory> missShaderBindingTable;
2649     de::MovePtr<BufferWithMemory> hitShaderBindingTable;
2650     de::MovePtr<BufferWithMemory> callableShaderBindingTable;
2651 
2652     VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion   = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2653     VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion     = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2654     VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion      = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2655     VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2656 
2657     {
2658         const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
2659 
2660         // Every case uses a ray generation shader.
2661         rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
2662                                       createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0),
2663                                       0);
2664 
2665         if (interfaceType == InterfaceType::RAY_PAYLOAD)
2666         {
2667             rayTracingPipeline->addShader(
2668                 VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2669                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
2670         }
2671         else if (interfaceType == InterfaceType::CALLABLE_DATA ||
2672                  interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2673         {
2674             rayTracingPipeline->addShader(
2675                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2676                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
2677         }
2678         else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
2679         {
2680             rayTracingPipeline->addShader(
2681                 VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
2682                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
2683             rayTracingPipeline->addShader(
2684                 VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2685                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
2686         }
2687         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2688         {
2689             rayTracingPipeline->addShader(
2690                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2691                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
2692             rayTracingPipeline->addShader(
2693                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2694                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("subcall"), 0), 2);
2695         }
2696         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
2697         {
2698             rayTracingPipeline->addShader(
2699                 VK_SHADER_STAGE_MISS_BIT_KHR,
2700                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0), 1);
2701             rayTracingPipeline->addShader(
2702                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2703                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
2704         }
2705         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2706         {
2707             rayTracingPipeline->addShader(
2708                 VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2709                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
2710             rayTracingPipeline->addShader(
2711                 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2712                 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
2713         }
2714         else
2715         {
2716             DE_ASSERT(false);
2717         }
2718 
2719         pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
2720 
2721         if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2722         {
2723             createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
2724                                       shaderGroupHandleSize, shaderGroupBaseAlignment, 0u, 1u, raygenShaderBindingTable,
2725                                       raygenShaderBindingTableRegion);
2726         }
2727         else
2728         {
2729             raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2730                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
2731             raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2732                 getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize,
2733                 shaderGroupHandleSize);
2734         }
2735 
2736         if (interfaceType == InterfaceType::CALLABLE_DATA || interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2737         {
2738             callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2739                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2740             callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2741                 getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
2742                 shaderGroupHandleSize);
2743         }
2744         else if (interfaceType == InterfaceType::RAY_PAYLOAD || interfaceType == InterfaceType::HIT_ATTRIBUTES)
2745         {
2746             hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2747                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2748             hitShaderBindingTableRegion =
2749                 makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
2750                                                   shaderGroupHandleSize, shaderGroupHandleSize);
2751         }
2752         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2753         {
2754             createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
2755                                       shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 2u,
2756                                       callableShaderBindingTable, callableShaderBindingTableRegion);
2757         }
2758         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
2759         {
2760             createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
2761                                       shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 1u, missShaderBindingTable,
2762                                       missShaderBindingTableRegion);
2763 
2764             // Callable shader table.
2765             callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2766                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
2767             callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2768                 getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
2769                 shaderGroupHandleSize);
2770         }
2771         else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2772         {
2773             createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
2774                                       shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 1u, hitShaderBindingTable,
2775                                       hitShaderBindingTableRegion);
2776 
2777             // Callable shader table.
2778             callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2779                 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
2780             callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2781                 getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
2782                 shaderGroupHandleSize);
2783         }
2784         else
2785         {
2786             DE_ASSERT(false);
2787         }
2788     }
2789 
2790     // Use ray tracing pipeline.
2791     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
2792     vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
2793                               &descriptorSet.get(), 0u, nullptr);
2794     vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
2795                         &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);
2796 
2797     // Synchronize output and callee buffers.
2798     const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
2799     vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
2800                            &memBarrier, 0u, nullptr, 0u, nullptr);
2801 
2802     endCommandBuffer(vkd, cmdBuffer);
2803     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
2804 
2805     // Verify storage buffer.
2806     invalidateAlloc(vkd, device, storageBufferAlloc);
2807     deMemcpy(storageBufferData.data(), storageBufferPtr, storageBufferSize);
2808 
2809     // These values must match what the shaders store.
2810     std::vector<uint32_t> expectedData;
2811     if (interfaceType == InterfaceType::RAY_PAYLOAD)
2812     {
2813         expectedData.push_back(103u);
2814         expectedData.push_back(100u);
2815     }
2816     else if (interfaceType == InterfaceType::CALLABLE_DATA)
2817     {
2818         expectedData.push_back(200u);
2819         expectedData.push_back(100u);
2820     }
2821     else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
2822     {
2823         expectedData.push_back(300u);
2824         expectedData.push_back(315u);
2825         expectedData.push_back(330u);
2826     }
2827     else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2828     {
2829         expectedData.push_back(402u);
2830         expectedData.push_back(450u);
2831     }
2832     else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2833     {
2834         expectedData.push_back(806u);
2835         expectedData.push_back(403u);
2836         expectedData.push_back(450u);
2837     }
2838     else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS ||
2839              interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2840     {
2841         expectedData.push_back(1200u);
2842         expectedData.push_back(400u);
2843         expectedData.push_back(490u);
2844     }
2845     else
2846     {
2847         DE_ASSERT(false);
2848     }
2849 
2850     size_t pos;
2851     for (pos = 0u; pos < expectedData.size(); ++pos)
2852     {
2853         const auto &stored   = storageBufferData.at(pos);
2854         const auto &expected = expectedData.at(pos);
2855         if (stored != expected)
2856         {
2857             std::ostringstream msg;
2858             msg << "Unexpected output value found at position " << pos << " (expected " << expected << " but got "
2859                 << stored << ")";
2860             return tcu::TestStatus::fail(msg.str());
2861         }
2862     }
2863 
2864     // Expect zeros in unused positions, as filled on the host.
2865     for (; pos < storageBufferData.size(); ++pos)
2866     {
2867         const auto &stored = storageBufferData.at(pos);
2868         if (stored != 0u)
2869         {
2870             std::ostringstream msg;
2871             msg << "Unexpected output value found at position " << pos << " (expected 0 but got " << stored << ")";
2872             return tcu::TestStatus::fail(msg.str());
2873         }
2874     }
2875 
2876     return tcu::TestStatus::pass("Pass");
2877 }
2878 
2879 } // anonymous namespace
2880 
createDataSpillTests(tcu::TestContext & testCtx)2881 tcu::TestCaseGroup *createDataSpillTests(tcu::TestContext &testCtx)
2882 {
2883     // Ray tracing tests for data spilling and unspilling around shader calls
2884     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "data_spill"));
2885 
2886     struct
2887     {
2888         CallType callType;
2889         const char *name;
2890     } callTypes[] = {
2891         {CallType::EXECUTE_CALLABLE, "execute_callable"},
2892         {CallType::TRACE_RAY, "trace_ray"},
2893         {CallType::REPORT_INTERSECTION, "report_intersection"},
2894     };
2895 
2896     struct
2897     {
2898         DataType dataType;
2899         const char *name;
2900     } dataTypes[] = {
2901         {DataType::INT32, "int32"},
2902         {DataType::UINT32, "uint32"},
2903         {DataType::INT64, "int64"},
2904         {DataType::UINT64, "uint64"},
2905         {DataType::INT16, "int16"},
2906         {DataType::UINT16, "uint16"},
2907         {DataType::INT8, "int8"},
2908         {DataType::UINT8, "uint8"},
2909         {DataType::FLOAT32, "float32"},
2910         {DataType::FLOAT64, "float64"},
2911         {DataType::FLOAT16, "float16"},
2912         {DataType::STRUCT, "struct"},
2913         {DataType::SAMPLER, "sampler"},
2914         {DataType::IMAGE, "image"},
2915         {DataType::SAMPLED_IMAGE, "combined"},
2916         {DataType::PTR_IMAGE, "ptr_image"},
2917         {DataType::PTR_SAMPLER, "ptr_sampler"},
2918         {DataType::PTR_SAMPLED_IMAGE, "ptr_combined"},
2919         {DataType::PTR_TEXEL, "ptr_texel"},
2920         {DataType::OP_NULL, "op_null"},
2921         {DataType::OP_UNDEF, "op_undef"},
2922     };
2923 
2924     struct
2925     {
2926         VectorType vectorType;
2927         const char *prefix;
2928     } vectorTypes[] = {
2929         {VectorType::SCALAR, ""}, {VectorType::V2, "v2"}, {VectorType::V3, "v3"},
2930         {VectorType::V4, "v4"},   {VectorType::A5, "a5"},
2931     };
2932 
2933     for (int callTypeIdx = 0; callTypeIdx < DE_LENGTH_OF_ARRAY(callTypes); ++callTypeIdx)
2934     {
2935         const auto &entryCallTypes = callTypes[callTypeIdx];
2936 
2937         de::MovePtr<tcu::TestCaseGroup> callTypeGroup(new tcu::TestCaseGroup(testCtx, entryCallTypes.name));
2938         for (int dataTypeIdx = 0; dataTypeIdx < DE_LENGTH_OF_ARRAY(dataTypes); ++dataTypeIdx)
2939         {
2940             const auto &entryDataTypes = dataTypes[dataTypeIdx];
2941 
2942             for (int vectorTypeIdx = 0; vectorTypeIdx < DE_LENGTH_OF_ARRAY(vectorTypes); ++vectorTypeIdx)
2943             {
2944                 const auto &entryVectorTypes = vectorTypes[vectorTypeIdx];
2945 
2946                 if ((samplersNeeded(entryDataTypes.dataType) || storageImageNeeded(entryDataTypes.dataType) ||
2947                      entryDataTypes.dataType == DataType::STRUCT || entryDataTypes.dataType == DataType::OP_NULL ||
2948                      entryDataTypes.dataType == DataType::OP_UNDEF) &&
2949                     entryVectorTypes.vectorType != VectorType::SCALAR)
2950                 {
2951                     continue;
2952                 }
2953 
2954                 DataSpillTestCase::TestParams params;
2955                 params.callType   = entryCallTypes.callType;
2956                 params.dataType   = entryDataTypes.dataType;
2957                 params.vectorType = entryVectorTypes.vectorType;
2958 
2959                 const auto testName = std::string(entryVectorTypes.prefix) + entryDataTypes.name;
2960 
2961                 callTypeGroup->addChild(new DataSpillTestCase(testCtx, testName, params));
2962             }
2963         }
2964 
2965         group->addChild(callTypeGroup.release());
2966     }
2967 
2968     // Pipeline interface tests.
2969     // Test data spilling and unspilling of pipeline interface variables
2970     de::MovePtr<tcu::TestCaseGroup> pipelineInterfaceGroup(new tcu::TestCaseGroup(testCtx, "pipeline_interface"));
2971 
2972     struct
2973     {
2974         InterfaceType interfaceType;
2975         const char *name;
2976     } interfaceTypes[] = {
2977         {InterfaceType::RAY_PAYLOAD, "ray_payload"},
2978         {InterfaceType::CALLABLE_DATA, "callable_data"},
2979         {InterfaceType::HIT_ATTRIBUTES, "hit_attributes"},
2980         {InterfaceType::SHADER_RECORD_BUFFER_RGEN, "shader_record_buffer_rgen"},
2981         {InterfaceType::SHADER_RECORD_BUFFER_CALL, "shader_record_buffer_call"},
2982         {InterfaceType::SHADER_RECORD_BUFFER_MISS, "shader_record_buffer_miss"},
2983         {InterfaceType::SHADER_RECORD_BUFFER_HIT, "shader_record_buffer_hit"},
2984     };
2985 
2986     for (int idx = 0; idx < DE_LENGTH_OF_ARRAY(interfaceTypes); ++idx)
2987     {
2988         const auto &entry = interfaceTypes[idx];
2989         DataSpillPipelineInterfaceTestCase::TestParams params;
2990 
2991         params.interfaceType = entry.interfaceType;
2992 
2993         pipelineInterfaceGroup->addChild(new DataSpillPipelineInterfaceTestCase(testCtx, entry.name, params));
2994     }
2995 
2996     group->addChild(pipelineInterfaceGroup.release());
2997 
2998     return group.release();
2999 }
3000 
3001 } // namespace RayTracing
3002 } // namespace vkt
3003