1 /*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 The Khronos Group Inc.
6 * Copyright (c) 2020 Valve Corporation.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Ray Tracing Data Spill tests
23 *//*--------------------------------------------------------------------*/
24 #include "vktRayTracingDataSpillTests.hpp"
25 #include "vktTestCase.hpp"
26
27 #include "vkRayTracingUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBufferWithMemory.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkCmdUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38
39 #include "deUniquePtr.hpp"
40 #include "deSTLUtil.hpp"
41
42 #include <sstream>
43 #include <string>
44 #include <map>
45 #include <vector>
46 #include <array>
47 #include <utility>
48
49 using namespace vk;
50
51 namespace vkt
52 {
53 namespace RayTracing
54 {
55
56 namespace
57 {
58
59 // The type of shader call that will be used.
60 enum class CallType
61 {
62 TRACE_RAY = 0,
63 EXECUTE_CALLABLE,
64 REPORT_INTERSECTION,
65 };
66
67 // The type of data that will be checked.
68 enum class DataType
69 {
// These can be made into an array or a vector.
71 INT32 = 0,
72 UINT32,
73 INT64,
74 UINT64,
75 INT16,
76 UINT16,
77 INT8,
78 UINT8,
79 FLOAT32,
80 FLOAT64,
81 FLOAT16,
82
83 // These are standalone, so the vector type should be scalar.
84 STRUCT,
85 IMAGE,
86 SAMPLER,
87 SAMPLED_IMAGE,
88 PTR_IMAGE,
89 PTR_SAMPLER,
90 PTR_SAMPLED_IMAGE,
91 PTR_TEXEL,
92 OP_NULL,
93 OP_UNDEF,
94 };
95
96 // The type of vector in use.
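// A5 stands for an array of 5 scalar elements rather than a vector (the array suffix is handled below).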
97 enum class VectorType
98 {
99 SCALAR = 1,
100 V2 = 2,
101 V3 = 3,
102 V4 = 4,
103 A5 = 5,
104 };
105
106 struct InputStruct
107 {
108 uint32_t uintPart;
109 float floatPart;
110 };
111
112 constexpr auto kImageFormat = VK_FORMAT_R32_UINT;
113 const auto kImageExtent = makeExtent3D(1u, 1u, 1u);
114
115 // For samplers.
116 const VkImageUsageFlags kSampledImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT);
117 constexpr size_t kNumImages = 4u;
118 constexpr size_t kNumSamplers = 4u;
119 constexpr size_t kNumCombined = 2u;
120 constexpr size_t kNumAloneImages = kNumImages - kNumCombined;
121 constexpr size_t kNumAloneSamplers = kNumSamplers - kNumCombined;
122
123 // For storage images.
124 const VkImageUsageFlags kStorageImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_STORAGE_BIT);
125
126 // For the pipeline interface tests.
127 constexpr size_t kNumStorageValues = 6u;
128 constexpr uint32_t kShaderRecordSize = sizeof(tcu::UVec4);
129
130 // Get the effective vector length in memory.
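// Note: a 3-component vector is counted as 4 components so that buffer sizes cover its vec4-like alignment in the storage buffer.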
size_t getEffectiveVectorLength(VectorType vectorType)
132 {
133 return ((vectorType == VectorType::V3) ? static_cast<size_t>(4) : static_cast<size_t>(vectorType));
134 }
135
136 // Get the corresponding element size.
VkDeviceSize getElementSize(DataType dataType, VectorType vectorType)
138 {
139 const size_t length = getEffectiveVectorLength(vectorType);
140 size_t dataSize = 0u;
141
142 switch (dataType)
143 {
144 case DataType::INT32:
145 dataSize = sizeof(int32_t);
146 break;
147 case DataType::UINT32:
148 dataSize = sizeof(uint32_t);
149 break;
150 case DataType::INT64:
151 dataSize = sizeof(int64_t);
152 break;
153 case DataType::UINT64:
154 dataSize = sizeof(uint64_t);
155 break;
156 case DataType::INT16:
157 dataSize = sizeof(int16_t);
158 break;
159 case DataType::UINT16:
160 dataSize = sizeof(uint16_t);
161 break;
162 case DataType::INT8:
163 dataSize = sizeof(int8_t);
164 break;
165 case DataType::UINT8:
166 dataSize = sizeof(uint8_t);
167 break;
168 case DataType::FLOAT32:
169 dataSize = sizeof(tcu::Float32);
170 break;
171 case DataType::FLOAT64:
172 dataSize = sizeof(tcu::Float64);
173 break;
174 case DataType::FLOAT16:
175 dataSize = sizeof(tcu::Float16);
176 break;
177 case DataType::STRUCT:
178 dataSize = sizeof(InputStruct);
179 break;
180 case DataType::IMAGE: // fallthrough.
181 case DataType::SAMPLER: // fallthrough.
182 case DataType::SAMPLED_IMAGE: // fallthrough.
183 case DataType::PTR_IMAGE: // fallthrough.
184 case DataType::PTR_SAMPLER: // fallthrough.
185 case DataType::PTR_SAMPLED_IMAGE: // fallthrough.
186 dataSize = sizeof(tcu::Float32);
187 break;
188 case DataType::PTR_TEXEL:
189 dataSize = sizeof(int32_t);
190 break;
191 case DataType::OP_NULL: // fallthrough.
192 case DataType::OP_UNDEF: // fallthrough.
193 dataSize = sizeof(uint32_t);
194 break;
195 default:
196 DE_ASSERT(false);
197 break;
198 }
199
200 return static_cast<VkDeviceSize>(dataSize * length);
201 }
202
203 // Proper stage for generating default geometry.
VkShaderStageFlagBits getShaderStageForGeometry(CallType type_)
205 {
206 VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
207
208 switch (type_)
209 {
210 case CallType::TRACE_RAY:
211 bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
212 break;
213 case CallType::EXECUTE_CALLABLE:
214 bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
215 break;
216 case CallType::REPORT_INTERSECTION:
217 bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
218 break;
219 default:
220 DE_ASSERT(false);
221 break;
222 }
223
224 DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
225 return bits;
226 }
227
VkShaderStageFlags getShaderStages(CallType type_)
229 {
230 VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
231
232 switch (type_)
233 {
234 case CallType::EXECUTE_CALLABLE:
235 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
236 break;
237 case CallType::TRACE_RAY:
238 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
239 break;
240 case CallType::REPORT_INTERSECTION:
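// Reporting an intersection also invokes an any-hit shader, which acts as the callee here.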
241 flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
242 flags |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
243 break;
244 default:
245 DE_ASSERT(false);
246 break;
247 }
248
249 return flags;
250 }
251
252 // Some test types need additional descriptors with samplers, images and combined image samplers.
bool samplersNeeded(DataType dataType)
254 {
255 bool needed = false;
256
257 switch (dataType)
258 {
259 case DataType::IMAGE:
260 case DataType::SAMPLER:
261 case DataType::SAMPLED_IMAGE:
262 case DataType::PTR_IMAGE:
263 case DataType::PTR_SAMPLER:
264 case DataType::PTR_SAMPLED_IMAGE:
265 needed = true;
266 break;
267 default:
268 break;
269 }
270
271 return needed;
272 }
273
274 // Some test types need an additional descriptor with a storage image.
bool storageImageNeeded(DataType dataType)
276 {
277 return (dataType == DataType::PTR_TEXEL);
278 }
279
280 // Returns two strings:
281 // .first is an optional GLSL additional type declaration (for structs, basically).
282 // .second is the value declaration inside the input block.
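// For example, (UINT32, V2) yields an empty .first and "u32vec2 val;" as .second, while STRUCT also returns the struct declaration in .first.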
std::pair<std::string, std::string> getGLSLInputValDecl(DataType dataType, VectorType vectorType)
284 {
285 using TypePair = std::pair<DataType, VectorType>;
286 using TypeMap = std::map<TypePair, std::string>;
287
288 const std::string varName = "val";
289 const auto dataTypeIdx = static_cast<int>(dataType);
290
291 if (dataTypeIdx >= static_cast<int>(DataType::INT32) && dataTypeIdx <= static_cast<int>(DataType::FLOAT16))
292 {
293 // Note: A5 uses the same type as the scalar version. The array suffix will be added below.
294 const TypeMap map = {
295 std::make_pair(std::make_pair(DataType::INT32, VectorType::SCALAR), "int32_t"),
296 std::make_pair(std::make_pair(DataType::INT32, VectorType::V2), "i32vec2"),
297 std::make_pair(std::make_pair(DataType::INT32, VectorType::V3), "i32vec3"),
298 std::make_pair(std::make_pair(DataType::INT32, VectorType::V4), "i32vec4"),
299 std::make_pair(std::make_pair(DataType::INT32, VectorType::A5), "int32_t"),
300 std::make_pair(std::make_pair(DataType::UINT32, VectorType::SCALAR), "uint32_t"),
301 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V2), "u32vec2"),
302 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V3), "u32vec3"),
303 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V4), "u32vec4"),
304 std::make_pair(std::make_pair(DataType::UINT32, VectorType::A5), "uint32_t"),
305 std::make_pair(std::make_pair(DataType::INT64, VectorType::SCALAR), "int64_t"),
306 std::make_pair(std::make_pair(DataType::INT64, VectorType::V2), "i64vec2"),
307 std::make_pair(std::make_pair(DataType::INT64, VectorType::V3), "i64vec3"),
308 std::make_pair(std::make_pair(DataType::INT64, VectorType::V4), "i64vec4"),
309 std::make_pair(std::make_pair(DataType::INT64, VectorType::A5), "int64_t"),
310 std::make_pair(std::make_pair(DataType::UINT64, VectorType::SCALAR), "uint64_t"),
311 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V2), "u64vec2"),
312 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V3), "u64vec3"),
313 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V4), "u64vec4"),
314 std::make_pair(std::make_pair(DataType::UINT64, VectorType::A5), "uint64_t"),
315 std::make_pair(std::make_pair(DataType::INT16, VectorType::SCALAR), "int16_t"),
316 std::make_pair(std::make_pair(DataType::INT16, VectorType::V2), "i16vec2"),
317 std::make_pair(std::make_pair(DataType::INT16, VectorType::V3), "i16vec3"),
318 std::make_pair(std::make_pair(DataType::INT16, VectorType::V4), "i16vec4"),
319 std::make_pair(std::make_pair(DataType::INT16, VectorType::A5), "int16_t"),
320 std::make_pair(std::make_pair(DataType::UINT16, VectorType::SCALAR), "uint16_t"),
321 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V2), "u16vec2"),
322 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V3), "u16vec3"),
323 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V4), "u16vec4"),
324 std::make_pair(std::make_pair(DataType::UINT16, VectorType::A5), "uint16_t"),
325 std::make_pair(std::make_pair(DataType::INT8, VectorType::SCALAR), "int8_t"),
326 std::make_pair(std::make_pair(DataType::INT8, VectorType::V2), "i8vec2"),
327 std::make_pair(std::make_pair(DataType::INT8, VectorType::V3), "i8vec3"),
328 std::make_pair(std::make_pair(DataType::INT8, VectorType::V4), "i8vec4"),
329 std::make_pair(std::make_pair(DataType::INT8, VectorType::A5), "int8_t"),
330 std::make_pair(std::make_pair(DataType::UINT8, VectorType::SCALAR), "uint8_t"),
331 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V2), "u8vec2"),
332 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V3), "u8vec3"),
333 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V4), "u8vec4"),
334 std::make_pair(std::make_pair(DataType::UINT8, VectorType::A5), "uint8_t"),
335 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::SCALAR), "float32_t"),
336 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V2), "f32vec2"),
337 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V3), "f32vec3"),
338 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V4), "f32vec4"),
339 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::A5), "float32_t"),
340 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::SCALAR), "float64_t"),
341 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V2), "f64vec2"),
342 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V3), "f64vec3"),
343 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V4), "f64vec4"),
344 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::A5), "float64_t"),
345 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::SCALAR), "float16_t"),
346 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V2), "f16vec2"),
347 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V3), "f16vec3"),
348 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V4), "f16vec4"),
349 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::A5), "float16_t"),
350 };
351
352 const auto key = std::make_pair(dataType, vectorType);
353 const auto found = map.find(key);
354
355 DE_ASSERT(found != end(map));
356
357 const auto baseType = found->second;
358 const std::string decl = baseType + " " + varName + ((vectorType == VectorType::A5) ? "[5]" : "") + ";";
359
360 return std::make_pair(std::string(), decl);
361 }
362 else if (dataType == DataType::STRUCT)
363 {
364 return std::make_pair(std::string("struct InputStruct { uint val1; float val2; };\n"),
365 std::string("InputStruct val;"));
366 }
367 else if (samplersNeeded(dataType))
368 {
369 return std::make_pair(std::string(), std::string("float val;"));
370 }
371 else if (storageImageNeeded(dataType))
372 {
373 return std::make_pair(std::string(), std::string("int val;"));
374 }
375 else if (dataType == DataType::OP_NULL || dataType == DataType::OP_UNDEF)
376 {
377 return std::make_pair(std::string(), std::string("uint val;"));
378 }
379
380 // Unreachable.
381 DE_ASSERT(false);
382 return std::make_pair(std::string(), std::string());
383 }
384
385 class DataSpillTestCase : public vkt::TestCase
386 {
387 public:
388 struct TestParams
389 {
390 CallType callType;
391 DataType dataType;
392 VectorType vectorType;
393 };
394
395 DataSpillTestCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &testParams);
virtual ~DataSpillTestCase(void)
397 {
398 }
399
400 virtual void initPrograms(vk::SourceCollections &programCollection) const;
401 virtual TestInstance *createInstance(Context &context) const;
402 virtual void checkSupport(Context &context) const;
403
404 private:
405 TestParams m_params;
406 };
407
408 class DataSpillTestInstance : public vkt::TestInstance
409 {
410 public:
411 using TestParams = DataSpillTestCase::TestParams;
412
413 DataSpillTestInstance(Context &context, const TestParams &testParams);
virtual ~DataSpillTestInstance(void)
415 {
416 }
417
418 virtual tcu::TestStatus iterate(void);
419
420 private:
421 TestParams m_params;
422 };
423
DataSpillTestCase::DataSpillTestCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &testParams)
425 : vkt::TestCase(testCtx, name)
426 , m_params(testParams)
427 {
428 switch (m_params.dataType)
429 {
430 case DataType::STRUCT:
431 case DataType::IMAGE:
432 case DataType::SAMPLER:
433 case DataType::SAMPLED_IMAGE:
434 case DataType::PTR_IMAGE:
435 case DataType::PTR_SAMPLER:
436 case DataType::PTR_SAMPLED_IMAGE:
437 case DataType::PTR_TEXEL:
438 case DataType::OP_NULL:
439 case DataType::OP_UNDEF:
440 DE_ASSERT(m_params.vectorType == VectorType::SCALAR);
441 break;
442 default:
443 break;
444 }
445
446 // The code assumes at most one of these is needed.
447 DE_ASSERT(!(samplersNeeded(m_params.dataType) && storageImageNeeded(m_params.dataType)));
448 }
449
TestInstance *DataSpillTestCase::createInstance(Context &context) const
451 {
452 return new DataSpillTestInstance(context, m_params);
453 }
454
DataSpillTestInstance::DataSpillTestInstance(Context &context, const TestParams &testParams)
456 : vkt::TestInstance(context)
457 , m_params(testParams)
458 {
459 }
460
461 // General checks for all tests.
void commonCheckSupport(Context &context)
463 {
464 context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
465 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
466
467 const auto &rtFeatures = context.getRayTracingPipelineFeatures();
468 if (!rtFeatures.rayTracingPipeline)
469 TCU_THROW(NotSupportedError, "Ray Tracing pipelines not supported");
470
471 const auto &asFeatures = context.getAccelerationStructureFeatures();
472 if (!asFeatures.accelerationStructure)
473 TCU_FAIL("VK_KHR_acceleration_structure supported without accelerationStructure support");
474 }
475
void DataSpillTestCase::checkSupport(Context &context) const
477 {
478 // General checks first.
479 commonCheckSupport(context);
480
481 const auto &features = context.getDeviceFeatures();
482 const auto &featuresStorage16 = context.get16BitStorageFeatures();
483 const auto &featuresF16I8 = context.getShaderFloat16Int8Features();
484 const auto &featuresStorage8 = context.get8BitStorageFeatures();
485
486 if (m_params.dataType == DataType::INT64 || m_params.dataType == DataType::UINT64)
487 {
488 if (!features.shaderInt64)
489 TCU_THROW(NotSupportedError, "64-bit integers not supported");
490 }
491 else if (m_params.dataType == DataType::INT16 || m_params.dataType == DataType::UINT16)
492 {
493 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
494
495 if (!features.shaderInt16)
496 TCU_THROW(NotSupportedError, "16-bit integers not supported");
497
498 if (!featuresStorage16.storageBuffer16BitAccess)
499 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
500 }
501 else if (m_params.dataType == DataType::INT8 || m_params.dataType == DataType::UINT8)
502 {
503 context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
504 context.requireDeviceFunctionality("VK_KHR_8bit_storage");
505
506 if (!featuresF16I8.shaderInt8)
507 TCU_THROW(NotSupportedError, "8-bit integers not supported");
508
509 if (!featuresStorage8.storageBuffer8BitAccess)
510 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
511 }
512 else if (m_params.dataType == DataType::FLOAT64)
513 {
514 if (!features.shaderFloat64)
515 TCU_THROW(NotSupportedError, "64-bit floats not supported");
516 }
517 else if (m_params.dataType == DataType::FLOAT16)
518 {
519 context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
520 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
521
522 if (!featuresF16I8.shaderFloat16)
523 TCU_THROW(NotSupportedError, "16-bit floats not supported");
524
525 if (!featuresStorage16.storageBuffer16BitAccess)
526 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
527 }
528 else if (samplersNeeded(m_params.dataType))
529 {
530 context.requireDeviceFunctionality("VK_EXT_descriptor_indexing");
531 const auto indexingFeatures = context.getDescriptorIndexingFeatures();
532 if (!indexingFeatures.shaderSampledImageArrayNonUniformIndexing)
533 TCU_THROW(NotSupportedError, "No support for non-uniform sampled image arrays");
534 }
535 }
536
void DataSpillTestCase::initPrograms(vk::SourceCollections &programCollection) const
538 {
539 const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
540 const vk::SpirVAsmBuildOptions spvBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, true);
541
542 std::ostringstream spvTemplateStream;
543
544 // This SPIR-V template will be used to generate shaders for different
545 // stages (raygen, callable, etc). The basic mechanism uses 3 SSBOs: one
546 // used strictly as an input, one to write the check result, and one to
547 // verify the shader call has taken place. The latter two SSBOs contain just
// a single uint, but the input SSBO typically contains other types of data
549 // that will be filled from the test instance with predetermined values. The
550 // shader will expect this data to have specific values that can be combined
// in some way to give an expected result (e.g. by adding the 4 components if
552 // it's a vec4). This result will be used in the shader call to make sure
553 // input values are read *before* the call. After the shader call has taken
554 // place, the shader will attempt to read the input buffer again and verify
555 // the value is still correct and matches the previous one. If the result
556 // matches, it will write a confirmation value in the check buffer. In the
// meantime, the callee will write a confirmation value in the callee
558 // buffer to verify the shader call took place.
559 //
560 // Some test variants use samplers, images or sampled images. These need
561 // additional bindings of different types and the interesting value is
// typically placed in the image rather than in the input buffer, while the
563 // input buffer is used for sampling coordinates instead.
564 //
565 // Some important SPIR-V template variables:
566 //
567 // - INPUT_BUFFER_VALUE_TYPE will contain the type of input buffer data.
568 // - CALC_ZERO_FOR_CALLABLE is expected to contain instructions that will
569 // calculate a value of zero to be used in the shader call instruction.
570 // This value should be derived from the input data.
571 // - CALL_STATEMENTS will contain the shader call instructions.
572 // - CALC_EQUAL_STATEMENT is expected to contain instructions that will
573 // set %equal to true as a %bool if the before- and after- data match.
574 //
575 // - %input_val_ptr contains the pointer to the input value.
576 // - %input_val_before contains the value read before the call.
577 // - %input_val_after contains the value read after the call.
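//
// As an illustration, the scalar UINT32 case below sets INPUT_BUFFER_VALUE_TYPE to
// %uint and CALC_ZERO_FOR_CALLABLE to an OpISub of the expected constant 37 from
// %input_val_before, so %zero_for_callable is only zero when the input was read
// correctly before the call.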
578
579 spvTemplateStream
580 << " OpCapability RayTracingKHR\n"
581 << "${EXTRA_CAPABILITIES}"
582 << " OpExtension \"SPV_KHR_ray_tracing\"\n"
583 << "${EXTRA_EXTENSIONS}"
584 << " OpMemoryModel Logical GLSL450\n"
585 << " OpEntryPoint ${ENTRY_POINT} %main \"main\" %topLevelAS %calleeBuffer "
586 "%outputBuffer %inputBuffer${MAIN_INTERFACE_EXTRAS}\n"
587 << "${INTERFACE_DECORATIONS}"
588 << " OpMemberDecorate %InputBlock 0 Offset 0\n"
589 << " OpDecorate %InputBlock Block\n"
590 << " OpDecorate %inputBuffer DescriptorSet 0\n"
591 << " OpDecorate %inputBuffer Binding 3\n"
592 << " OpMemberDecorate %OutputBlock 0 Offset 0\n"
593 << " OpDecorate %OutputBlock Block\n"
594 << " OpDecorate %outputBuffer DescriptorSet 0\n"
595 << " OpDecorate %outputBuffer Binding 2\n"
596 << " OpMemberDecorate %CalleeBlock 0 Offset 0\n"
597 << " OpDecorate %CalleeBlock Block\n"
598 << " OpDecorate %calleeBuffer DescriptorSet 0\n"
599 << " OpDecorate %calleeBuffer Binding 1\n"
600 << " OpDecorate %topLevelAS DescriptorSet 0\n"
601 << " OpDecorate %topLevelAS Binding 0\n"
602 << "${EXTRA_BINDINGS}"
603 << " %void = OpTypeVoid\n"
604 << " %void_func = OpTypeFunction %void\n"
605 << " %int = OpTypeInt 32 1\n"
606 << " %uint = OpTypeInt 32 0\n"
607 << " %int_0 = OpConstant %int 0\n"
608 << " %uint_0 = OpConstant %uint 0\n"
609 << " %uint_1 = OpConstant %uint 1\n"
610 << " %uint_2 = OpConstant %uint 2\n"
611 << " %uint_3 = OpConstant %uint 3\n"
612 << " %uint_4 = OpConstant %uint 4\n"
613 << " %uint_5 = OpConstant %uint 5\n"
614 << " %uint_255 = OpConstant %uint 255\n"
615 << " %bool = OpTypeBool\n"
616 << " %float = OpTypeFloat 32\n"
617 << " %float_0 = OpConstant %float 0\n"
618 << " %float_1 = OpConstant %float 1\n"
619 << " %float_9 = OpConstant %float 9\n"
620 << " %float_0_5 = OpConstant %float 0.5\n"
621 << " %float_n1 = OpConstant %float -1\n"
622 << " %v3float = OpTypeVector %float 3\n"
623 << " %origin_const = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0\n"
624 << " %direction_const = OpConstantComposite %v3float %float_0 %float_0 %float_n1\n"
625 << "${EXTRA_TYPES_AND_CONSTANTS}"
626 << " %data_func_ptr = OpTypePointer Function ${INPUT_BUFFER_VALUE_TYPE}\n"
627 << "${INTERFACE_TYPES_AND_VARIABLES}"
628 << " %InputBlock = OpTypeStruct ${INPUT_BUFFER_VALUE_TYPE}\n"
629 << " %_ptr_StorageBuffer_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
630 << " %inputBuffer = OpVariable %_ptr_StorageBuffer_InputBlock StorageBuffer\n"
631 << " %data_storagebuffer_ptr = OpTypePointer StorageBuffer ${INPUT_BUFFER_VALUE_TYPE}\n"
632 << " %OutputBlock = OpTypeStruct %uint\n"
633 << "%_ptr_StorageBuffer_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
634 << " %outputBuffer = OpVariable %_ptr_StorageBuffer_OutputBlock StorageBuffer\n"
635 << " %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint\n"
636 << " %CalleeBlock = OpTypeStruct %uint\n"
637 << "%_ptr_StorageBuffer_CalleeBlock = OpTypePointer StorageBuffer %CalleeBlock\n"
638 << " %calleeBuffer = OpVariable %_ptr_StorageBuffer_CalleeBlock StorageBuffer\n"
639 << " %as_type = OpTypeAccelerationStructureKHR\n"
640 << " %as_uniformconstant_ptr = OpTypePointer UniformConstant %as_type\n"
641 << " %topLevelAS = OpVariable %as_uniformconstant_ptr UniformConstant\n"
642 << "${EXTRA_BINDING_VARIABLES}"
643 << " %main = OpFunction %void None %void_func\n"
644 << " %main_label = OpLabel\n"
645 << "${EXTRA_FUNCTION_VARIABLES}"
646 << " %input_val_ptr = OpAccessChain %data_storagebuffer_ptr %inputBuffer %int_0\n"
647 << " %output_val_ptr = OpAccessChain %_ptr_StorageBuffer_uint %outputBuffer %int_0\n"
648 // Note we use Volatile to load the input buffer value before and after the call statements.
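// This keeps the two loads from being folded into one, so the second load really observes the data after the call (and any spilling it caused).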
649 << " %input_val_before = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
650 << "${CALC_ZERO_FOR_CALLABLE}"
651 << "${CALL_STATEMENTS}"
652 << " %input_val_after = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
653 << "${CALC_EQUAL_STATEMENT}"
654 << " %output_val = OpSelect %uint %equal %uint_1 %uint_0\n"
655 << " OpStore %output_val_ptr %output_val\n"
656 << " OpReturn\n"
657 << " OpFunctionEnd\n";
658
659 const tcu::StringTemplate spvTemplate(spvTemplateStream.str());
660
661 std::map<std::string, std::string> subs;
662 std::string componentTypeName;
663 std::string opEqual;
664 const int numComponents = static_cast<int>(m_params.vectorType);
665 const auto isArray = (numComponents > static_cast<int>(VectorType::V4));
666 const auto numComponentsStr = de::toString(numComponents);
667
668 subs["EXTRA_CAPABILITIES"] = "";
669 subs["EXTRA_EXTENSIONS"] = "";
670 subs["EXTRA_TYPES_AND_CONSTANTS"] = "";
671 subs["EXTRA_FUNCTION_VARIABLES"] = "";
672 subs["EXTRA_BINDINGS"] = "";
673 subs["EXTRA_BINDING_VARIABLES"] = "";
674 subs["EXTRA_FUNCTIONS"] = "";
675
// Note that some of these substitutions will be updated after the if-block.
677
678 if (m_params.dataType == DataType::INT32)
679 {
680 componentTypeName = "int";
681
682 subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
683 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %int_37 = OpConstant %int 37\n";
684 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_int = OpISub %int %input_val_before %int_37\n"
685 " %zero_for_callable = OpBitcast %uint %zero_int\n";
686 }
687 else if (m_params.dataType == DataType::UINT32)
688 {
689 componentTypeName = "uint";
690
691 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
692 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n";
693 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_for_callable = OpISub %uint %input_val_before %uint_37\n";
694 }
695 else if (m_params.dataType == DataType::INT64)
696 {
697 componentTypeName = "long";
698
699 subs["EXTRA_CAPABILITIES"] += " OpCapability Int64\n";
700 subs["INPUT_BUFFER_VALUE_TYPE"] = "%long";
701 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %long = OpTypeInt 64 1\n"
702 " %long_37 = OpConstant %long 37\n";
703 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_long = OpISub %long %input_val_before %long_37\n"
704 " %zero_for_callable = OpSConvert %uint %zero_long\n";
705 }
706 else if (m_params.dataType == DataType::UINT64)
707 {
708 componentTypeName = "ulong";
709
710 subs["EXTRA_CAPABILITIES"] += " OpCapability Int64\n";
711 subs["INPUT_BUFFER_VALUE_TYPE"] = "%ulong";
712 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %ulong = OpTypeInt 64 0\n"
713 " %ulong_37 = OpConstant %ulong 37\n";
714 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_ulong = OpISub %ulong %input_val_before %ulong_37\n"
715 " %zero_for_callable = OpUConvert %uint %zero_ulong\n";
716 }
717 else if (m_params.dataType == DataType::INT16)
718 {
719 componentTypeName = "short";
720
721 subs["EXTRA_CAPABILITIES"] += " OpCapability Int16\n"
722 " OpCapability StorageBuffer16BitAccess\n";
723 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
724 subs["INPUT_BUFFER_VALUE_TYPE"] = "%short";
725 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %short = OpTypeInt 16 1\n"
726 " %short_37 = OpConstant %short 37\n";
727 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_short = OpISub %short %input_val_before %short_37\n"
728 " %zero_for_callable = OpSConvert %uint %zero_short\n";
729 }
730 else if (m_params.dataType == DataType::UINT16)
731 {
732 componentTypeName = "ushort";
733
734 subs["EXTRA_CAPABILITIES"] += " OpCapability Int16\n"
735 " OpCapability StorageBuffer16BitAccess\n";
736 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
737 subs["INPUT_BUFFER_VALUE_TYPE"] = "%ushort";
738 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %ushort = OpTypeInt 16 0\n"
739 " %ushort_37 = OpConstant %ushort 37\n";
740 subs["CALC_ZERO_FOR_CALLABLE"] =
741 " %zero_ushort = OpISub %ushort %input_val_before %ushort_37\n"
742 " %zero_for_callable = OpUConvert %uint %zero_ushort\n";
743 }
744 else if (m_params.dataType == DataType::INT8)
745 {
746 componentTypeName = "char";
747
748 subs["EXTRA_CAPABILITIES"] += " OpCapability Int8\n"
749 " OpCapability StorageBuffer8BitAccess\n";
750 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_8bit_storage\"\n";
751 subs["INPUT_BUFFER_VALUE_TYPE"] = "%char";
752 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %char = OpTypeInt 8 1\n"
753 " %char_37 = OpConstant %char 37\n";
754 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_char = OpISub %char %input_val_before %char_37\n"
755 " %zero_for_callable = OpSConvert %uint %zero_char\n";
756 }
757 else if (m_params.dataType == DataType::UINT8)
758 {
759 componentTypeName = "uchar";
760
761 subs["EXTRA_CAPABILITIES"] += " OpCapability Int8\n"
762 " OpCapability StorageBuffer8BitAccess\n";
763 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_8bit_storage\"\n";
764 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uchar";
765 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uchar = OpTypeInt 8 0\n"
766 " %uchar_37 = OpConstant %uchar 37\n";
767 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_uchar = OpISub %uchar %input_val_before %uchar_37\n"
768 " %zero_for_callable = OpUConvert %uint %zero_uchar\n";
769 }
770 else if (m_params.dataType == DataType::FLOAT32)
771 {
772 componentTypeName = "float";
773
774 subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
775 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %float_37 = OpConstant %float 37\n";
776 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_float = OpFSub %float %input_val_before %float_37\n"
777 " %zero_for_callable = OpConvertFToU %uint %zero_float\n";
778 }
779 else if (m_params.dataType == DataType::FLOAT64)
780 {
781 componentTypeName = "double";
782
783 subs["EXTRA_CAPABILITIES"] += " OpCapability Float64\n";
784 subs["INPUT_BUFFER_VALUE_TYPE"] = "%double";
785 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %double = OpTypeFloat 64\n"
786 " %double_37 = OpConstant %double 37\n";
787 subs["CALC_ZERO_FOR_CALLABLE"] =
788 " %zero_double = OpFSub %double %input_val_before %double_37\n"
789 " %zero_for_callable = OpConvertFToU %uint %zero_double\n";
790 }
791 else if (m_params.dataType == DataType::FLOAT16)
792 {
793 componentTypeName = "half";
794
795 subs["EXTRA_CAPABILITIES"] += " OpCapability Float16\n"
796 " OpCapability StorageBuffer16BitAccess\n";
797 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
798 subs["INPUT_BUFFER_VALUE_TYPE"] = "%half";
799 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %half = OpTypeFloat 16\n"
800 " %half_37 = OpConstant %half 37\n";
801 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_half = OpFSub %half %input_val_before %half_37\n"
802 " %zero_for_callable = OpConvertFToU %uint %zero_half\n";
803 }
804 else if (m_params.dataType == DataType::STRUCT)
805 {
806 componentTypeName = "InputStruct";
807
808 subs["INPUT_BUFFER_VALUE_TYPE"] = "%InputStruct";
809 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %InputStruct = OpTypeStruct %uint %float\n"
810 " %float_37 = OpConstant %float 37\n"
811 " %uint_part_ptr_type = OpTypePointer StorageBuffer %uint\n"
812 " %float_part_ptr_type = OpTypePointer StorageBuffer %float\n"
813 " %uint_part_func_ptr_type = OpTypePointer Function %uint\n"
814 " %float_part_func_ptr_type = OpTypePointer Function %float\n"
815 " %input_struct_func_ptr_type = OpTypePointer Function %InputStruct\n";
816 subs["INTERFACE_DECORATIONS"] = " OpMemberDecorate %InputStruct 0 Offset 0\n"
817 " OpMemberDecorate %InputStruct 1 Offset 4\n";
818
// Sum struct members, then subtract the constant and convert to uint.
820 subs["CALC_ZERO_FOR_CALLABLE"] =
821 " %uint_part_ptr = OpAccessChain %uint_part_ptr_type %input_val_ptr %uint_0\n"
822 " %float_part_ptr = OpAccessChain %float_part_ptr_type %input_val_ptr %uint_1\n"
823 " %uint_part = OpLoad %uint %uint_part_ptr\n"
824 " %float_part = OpLoad %float %float_part_ptr\n"
825 " %uint_as_float = OpConvertUToF %float %uint_part\n"
826 " %member_sum = OpFAdd %float %float_part %uint_as_float\n"
827 " %zero_float = OpFSub %float %member_sum %float_37\n"
828 " %zero_for_callable = OpConvertFToU %uint %zero_float\n";
829 }
830 else if (samplersNeeded(m_params.dataType))
831 {
832 // These tests will use additional bindings as arrays of 2 elements:
833 // - 1 array of samplers.
834 // - 1 array of images.
835 // - 1 array of combined image samplers.
// Input values are typically used as texture coordinates (normally zeros).
// The pixels, rather than the input buffer, will contain the expected values.
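// The "before" read samples the expected pixel value (37) using coordinates taken from the input buffer.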
838
839 subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
840 subs["EXTRA_CAPABILITIES"] +=
841 " OpCapability SampledImageArrayNonUniformIndexing\n";
842 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_EXT_descriptor_indexing\"\n";
843 subs["MAIN_INTERFACE_EXTRAS"] += " %sampledTexture %textureSampler %combinedImageSampler";
844 subs["EXTRA_BINDINGS"] += " OpDecorate %sampledTexture DescriptorSet 0\n"
845 " OpDecorate %sampledTexture Binding 4\n"
846 " OpDecorate %textureSampler DescriptorSet 0\n"
847 " OpDecorate %textureSampler Binding 5\n"
848 " OpDecorate %combinedImageSampler DescriptorSet 0\n"
849 " OpDecorate %combinedImageSampler Binding 6\n";
850 subs["EXTRA_TYPES_AND_CONSTANTS"] +=
851 " %uint_37 = OpConstant %uint 37\n"
852 " %v4uint = OpTypeVector %uint 4\n"
853 " %v2float = OpTypeVector %float 2\n"
854 " %image_type = OpTypeImage %uint 2D 0 0 0 1 Unknown\n"
855 " %image_array_type = OpTypeArray %image_type %uint_2\n"
856 " %image_array_type_uniform_ptr = OpTypePointer UniformConstant %image_array_type\n"
857 " %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
858 " %sampler_type = OpTypeSampler\n"
859 " %sampler_array_type = OpTypeArray %sampler_type %uint_2\n"
860 "%sampler_array_type_uniform_ptr = OpTypePointer UniformConstant %sampler_array_type\n"
861 " %sampler_type_uniform_ptr = OpTypePointer UniformConstant %sampler_type\n"
862 " %sampled_image_type = OpTypeSampledImage %image_type\n"
863 " %sampled_image_array_type = OpTypeArray %sampled_image_type %uint_2\n"
864 "%sampled_image_array_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_array_type\n"
865 "%sampled_image_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_type\n";
866 subs["EXTRA_BINDING_VARIABLES"] +=
867 " %sampledTexture = OpVariable %image_array_type_uniform_ptr UniformConstant\n"
868 " %textureSampler = OpVariable %sampler_array_type_uniform_ptr UniformConstant\n"
869 " %combinedImageSampler = OpVariable %sampled_image_array_type_uniform_ptr UniformConstant\n";
870
871 if (m_params.dataType == DataType::IMAGE || m_params.dataType == DataType::SAMPLER)
872 {
873 // Use the first sampler and sample from the first image.
874 subs["CALC_ZERO_FOR_CALLABLE"] +=
875 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
876 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
877 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
878 "%image_0 = OpLoad %image_type %image_0_ptr\n"
879 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
880 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
881 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
882 "%float_0\n"
883 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
884 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
885 }
886 else if (m_params.dataType == DataType::SAMPLED_IMAGE)
887 {
888 // Use the first combined image sampler.
889 subs["CALC_ZERO_FOR_CALLABLE"] +=
890 "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
891 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
892 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
893 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
894 "%float_0\n"
895 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
896 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
897 }
898 else if (m_params.dataType == DataType::PTR_IMAGE)
899 {
900 // We attempt to create the second pointer before the call.
901 subs["CALC_ZERO_FOR_CALLABLE"] +=
902 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
903 "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
904 "%image_0 = OpLoad %image_type %image_0_ptr\n"
905 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
906 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
907 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
908 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
909 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
910 "%float_0\n"
911 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
912 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
913 }
914 else if (m_params.dataType == DataType::PTR_SAMPLER)
915 {
916 // We attempt to create the second pointer before the call.
917 subs["CALC_ZERO_FOR_CALLABLE"] +=
918 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
919 "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
920 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
921 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
922 "%image_0 = OpLoad %image_type %image_0_ptr\n"
923 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
924 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
925 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
926 "%float_0\n"
927 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
928 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
929 }
930 else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
931 {
932 // We attempt to create the second pointer before the call.
933 subs["CALC_ZERO_FOR_CALLABLE"] +=
934 "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
935 "%sampled_image_1_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_1\n"
936 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
937 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
938 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend "
939 "%float_0\n"
940 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
941 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n";
942 }
943 else
944 {
945 DE_ASSERT(false);
946 }
947 }
948 else if (storageImageNeeded(m_params.dataType))
949 {
950 subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
951 subs["MAIN_INTERFACE_EXTRAS"] += " %storageImage";
952 subs["EXTRA_BINDINGS"] += " OpDecorate %storageImage DescriptorSet 0\n"
953 " OpDecorate %storageImage Binding 4\n";
954 subs["EXTRA_TYPES_AND_CONSTANTS"] +=
955 " %uint_37 = OpConstant %uint 37\n"
956 " %v2int = OpTypeVector %int 2\n"
957 " %image_type = OpTypeImage %uint 2D 0 0 0 2 R32ui\n"
958 " %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
959 " %uint_img_ptr = OpTypePointer Image %uint\n";
960 subs["EXTRA_BINDING_VARIABLES"] +=
961 " %storageImage = OpVariable %image_type_uniform_ptr UniformConstant\n";
962
// Load the value from the image, expecting it to be 37, and swap it with 5.
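// OpAtomicCompareExchange returns the original texel value (expected to be 37) and stores 5 in its place; the post-call check verifies that the 5 was stored.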
964 subs["CALC_ZERO_FOR_CALLABLE"] +=
965 "%coords = OpCompositeConstruct %v2int %input_val_before %input_val_before\n"
966 "%texel_ptr = OpImageTexelPointer %uint_img_ptr %storageImage %coords %uint_0\n"
967 "%texel_value = OpAtomicCompareExchange %uint %texel_ptr %uint_1 %uint_0 %uint_0 %uint_5 %uint_37\n"
968 "%zero_for_callable = OpISub %uint %texel_value %uint_37\n";
969 }
970 else if (m_params.dataType == DataType::OP_NULL)
971 {
972 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
973 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n"
974 " %constant_null = OpConstantNull %uint\n";
975
976 // Create a local copy of the null constant global object to work with it.
977 subs["CALC_ZERO_FOR_CALLABLE"] +=
978 "%constant_null_copy = OpCopyObject %uint %constant_null\n"
979 "%is_37_before = OpIEqual %bool %input_val_before %uint_37\n"
980 "%zero_for_callable = OpSelect %uint %is_37_before %constant_null_copy %uint_5\n";
981 }
982 else if (m_params.dataType == DataType::OP_UNDEF)
983 {
984 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
985 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n";
986
987 // Extract an undef value and write it to the output buffer to make sure it's used before the call. The value will be overwritten later.
988 subs["CALC_ZERO_FOR_CALLABLE"] += "%undef_var = OpUndef %uint\n"
989 "%undef_val_before = OpCopyObject %uint %undef_var\n"
990 "OpStore %output_val_ptr %undef_val_before Volatile\n"
991 "%zero_for_callable = OpISub %uint %uint_37 %input_val_before\n";
992 }
993 else
994 {
995 DE_ASSERT(false);
996 }
997
998 // Comparison statement for data before and after the call.
999 switch (m_params.dataType)
1000 {
1001 case DataType::INT32:
1002 case DataType::UINT32:
1003 case DataType::INT64:
1004 case DataType::UINT64:
1005 case DataType::INT16:
1006 case DataType::UINT16:
1007 case DataType::INT8:
1008 case DataType::UINT8:
1009 opEqual = "OpIEqual";
1010 break;
1011 case DataType::FLOAT32:
1012 case DataType::FLOAT64:
1013 case DataType::FLOAT16:
1014 opEqual = "OpFOrdEqual";
1015 break;
1016 case DataType::STRUCT:
1017 case DataType::IMAGE:
1018 case DataType::SAMPLER:
1019 case DataType::SAMPLED_IMAGE:
1020 case DataType::PTR_IMAGE:
1021 case DataType::PTR_SAMPLER:
1022 case DataType::PTR_SAMPLED_IMAGE:
1023 case DataType::PTR_TEXEL:
1024 case DataType::OP_NULL:
1025 case DataType::OP_UNDEF:
// These need special code for the comparison.
1027 opEqual = "INVALID";
1028 break;
1029 default:
1030 DE_ASSERT(false);
1031 break;
1032 }
1033
1034 if (m_params.dataType == DataType::STRUCT)
1035 {
// We need to store the before and after values in function variables in order to access each member individually without touching the StorageBuffer again.
1037 subs["EXTRA_FUNCTION_VARIABLES"] =
1038 " %input_val_func_before = OpVariable %input_struct_func_ptr_type Function\n"
1039 " %input_val_func_after = OpVariable %input_struct_func_ptr_type Function\n";
1040 subs["CALC_EQUAL_STATEMENT"] =
1041 " OpStore %input_val_func_before %input_val_before\n"
1042 " OpStore %input_val_func_after %input_val_after\n"
1043 " %uint_part_func_before_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_before %uint_0\n"
1044 " %float_part_func_before_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_before %uint_1\n"
1045 " %uint_part_func_after_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_after %uint_0\n"
1046 " %float_part_func_after_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_after %uint_1\n"
1047 " %uint_part_before = OpLoad %uint %uint_part_func_before_ptr\n"
1048 " %float_part_before = OpLoad %float %float_part_func_before_ptr\n"
1049 " %uint_part_after = OpLoad %uint %uint_part_func_after_ptr\n"
1050 " %float_part_after = OpLoad %float %float_part_func_after_ptr\n"
1051 " %uint_equal = OpIEqual %bool %uint_part_before %uint_part_after\n"
1052 " %float_equal = OpFOrdEqual %bool %float_part_before %float_part_after\n"
1053 " %equal = OpLogicalAnd %bool %uint_equal %float_equal\n";
1054 }
1055 else if (m_params.dataType == DataType::IMAGE)
1056 {
1057 // Use the same image and the second sampler with different coordinates (actually the same).
1058 subs["CALC_EQUAL_STATEMENT"] +=
1059 "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
1060 "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1061 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1062 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1063 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1064 "%float_0\n"
1065 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1066 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1067 }
1068 else if (m_params.dataType == DataType::SAMPLER)
1069 {
1070 // Use the same sampler and sample from the second image with different coordinates (but actually the same).
1071 subs["CALC_EQUAL_STATEMENT"] +=
1072 "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
1073 "%image_1 = OpLoad %image_type %image_1_ptr\n"
1074 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1075 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1076 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1077 "%float_0\n"
1078 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1079 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1080 }
1081 else if (m_params.dataType == DataType::SAMPLED_IMAGE)
1082 {
1083 // Reuse the same combined image sampler with different coordinates (actually the same).
1084 subs["CALC_EQUAL_STATEMENT"] +=
1085 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1086 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_1 Lod|ZeroExtend "
1087 "%float_0\n"
1088 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1089 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1090 }
1091 else if (m_params.dataType == DataType::PTR_IMAGE)
1092 {
1093 // We attempt to use the second pointer only after the call.
1094 subs["CALC_EQUAL_STATEMENT"] +=
1095 "%image_1 = OpLoad %image_type %image_1_ptr\n"
1096 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1097 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1098 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1099 "%float_0\n"
1100 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1101 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1102 }
1103 else if (m_params.dataType == DataType::PTR_SAMPLER)
1104 {
1105 // We attempt to use the second pointer only after the call.
1106 subs["CALC_EQUAL_STATEMENT"] +=
1107 "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1108 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1109 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1110 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1111 "%float_0\n"
1112 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1113 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1114 }
1115 else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
1116 {
1117 // We attempt to use the second pointer only after the call.
1118 subs["CALC_EQUAL_STATEMENT"] +=
1119 "%sampled_image_1 = OpLoad %sampled_image_type %sampled_image_1_ptr\n"
1120 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1121 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend "
1122 "%float_0\n"
1123 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1124 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n";
1125 }
1126 else if (m_params.dataType == DataType::PTR_TEXEL)
1127 {
1128 // Check value 5 was stored properly.
1129 subs["CALC_EQUAL_STATEMENT"] += "%stored_val = OpAtomicLoad %uint %texel_ptr %uint_1 %uint_0\n"
1130 "%equal = OpIEqual %bool %stored_val %uint_5\n";
1131 }
1132 else if (m_params.dataType == DataType::OP_NULL)
1133 {
1134 // Reuse the null constant after the call.
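// With the input still at 37, the select yields the (zero) null constant, which is written back and must read back as zero.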
1135 subs["CALC_EQUAL_STATEMENT"] += "%is_37_after = OpIEqual %bool %input_val_after %uint_37\n"
1136 "%writeback_val = OpSelect %uint %is_37_after %constant_null_copy %uint_5\n"
1137 "OpStore %input_val_ptr %writeback_val Volatile\n"
1138 "%readback_val = OpLoad %uint %input_val_ptr Volatile\n"
1139 "%equal = OpIEqual %bool %readback_val %uint_0\n";
1140 }
1141 else if (m_params.dataType == DataType::OP_UNDEF)
1142 {
1143 // Extract another undef value and write it to the input buffer. It will not be checked later.
1144 subs["CALC_EQUAL_STATEMENT"] += "%undef_val_after = OpCopyObject %uint %undef_var\n"
1145 "OpStore %input_val_ptr %undef_val_after Volatile\n"
1146 "%equal = OpIEqual %bool %input_val_after %input_val_before\n";
1147 }
1148 else
1149 {
1150 subs["CALC_EQUAL_STATEMENT"] +=
1151 " %equal = " + opEqual + " %bool %input_val_before %input_val_after\n";
1152 }
1153
1154 // Modifications for vectors and arrays.
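// For these, the expected 37 is the sum of all components, so the components are loaded and added one by one before subtracting the constant, and the equality check is reworked per component (arrays) or with a bool vector (vectors).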
1155 if (numComponents > 1)
1156 {
1157 const std::string vectorTypeName = "v" + numComponentsStr + componentTypeName;
1158 const std::string opType = (isArray ? "OpTypeArray" : "OpTypeVector");
1159 const std::string componentCountStr = (isArray ? ("%uint_" + numComponentsStr) : numComponentsStr);
1160
1161 // Some extra types are needed.
1162 if (!(m_params.dataType == DataType::FLOAT32 && m_params.vectorType == VectorType::V3))
1163 {
1164 // Note: v3float is already defined in the shader by default.
1165 subs["EXTRA_TYPES_AND_CONSTANTS"] +=
1166 "%" + vectorTypeName + " = " + opType + " %" + componentTypeName + " " + componentCountStr + "\n";
1167 }
1168 subs["EXTRA_TYPES_AND_CONSTANTS"] +=
1169 "%v" + numComponentsStr + "bool = " + opType + " %bool " + componentCountStr + "\n";
1170 subs["EXTRA_TYPES_AND_CONSTANTS"] += "%comp_ptr = OpTypePointer StorageBuffer %" + componentTypeName + "\n";
1171
1172 // The input value in the buffer has a different type.
1173 subs["INPUT_BUFFER_VALUE_TYPE"] = "%" + vectorTypeName;
1174
1175 // Overwrite the way we calculate the zero used in the call.
1176
// Proper operations for adding, subtracting and converting components.
1178 std::string opAdd;
1179 std::string opSub;
1180 std::string opConvert;
1181
1182 switch (m_params.dataType)
1183 {
1184 case DataType::INT32:
1185 case DataType::UINT32:
1186 case DataType::INT64:
1187 case DataType::UINT64:
1188 case DataType::INT16:
1189 case DataType::UINT16:
1190 case DataType::INT8:
1191 case DataType::UINT8:
1192 opAdd = "OpIAdd";
1193 opSub = "OpISub";
1194 break;
1195 case DataType::FLOAT32:
1196 case DataType::FLOAT64:
1197 case DataType::FLOAT16:
1198 opAdd = "OpFAdd";
1199 opSub = "OpFSub";
1200 break;
1201 default:
1202 DE_ASSERT(false);
1203 break;
1204 }
1205
1206 switch (m_params.dataType)
1207 {
1208 case DataType::UINT32:
1209 opConvert = "OpCopyObject";
1210 break;
1211 case DataType::INT32:
1212 opConvert = "OpBitcast";
1213 break;
1214 case DataType::INT64:
1215 case DataType::INT16:
1216 case DataType::INT8:
1217 opConvert = "OpSConvert";
1218 break;
1219 case DataType::UINT64:
1220 case DataType::UINT16:
1221 case DataType::UINT8:
1222 opConvert = "OpUConvert";
1223 break;
1224 case DataType::FLOAT32:
1225 case DataType::FLOAT64:
1226 case DataType::FLOAT16:
1227 opConvert = "OpConvertFToU";
1228 break;
1229 default:
1230 DE_ASSERT(false);
1231 break;
1232 }
1233
1234 std::ostringstream zeroForCallable;
1235
1236 // Create pointers to components and load components.
1237 for (int i = 0; i < numComponents; ++i)
1238 {
1239 zeroForCallable << "%component_ptr_" << i << " = OpAccessChain %comp_ptr %input_val_ptr %uint_" << i << "\n"
1240 << "%component_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i
1241 << "\n";
1242 }
1243
1244 // Sum components together in %total_sum.
1245 for (int i = 1; i < numComponents; ++i)
1246 {
1247 const std::string previous = ((i == 1) ? "%component_0" : ("%partial_" + de::toString(i - 1)));
1248 const std::string resultName =
1249 ((i == (numComponents - 1)) ? "%total_sum" : ("%partial_" + de::toString(i)));
1250 zeroForCallable << resultName << " = " << opAdd << " %" << componentTypeName << " %component_" << i << " "
1251 << previous << "\n";
1252 }
1253
1254 // Recalculate the zero.
1255 zeroForCallable << "%zero_" << componentTypeName << " = " << opSub << " %" << componentTypeName
1256 << " %total_sum %" << componentTypeName << "_37\n"
1257 << "%zero_for_callable = " << opConvert << " %uint %zero_" << componentTypeName << "\n";
1258
1259 // Finally replace the zero_for_callable statements with the special version for vectors.
1260 subs["CALC_ZERO_FOR_CALLABLE"] = zeroForCallable.str();
1261
1262 // Rework comparison statements.
1263 if (isArray)
1264 {
1265 // Arrays need to be compared per-component.
1266 std::ostringstream calcEqual;
1267
1268 for (int i = 0; i < numComponents; ++i)
1269 {
1270 calcEqual << "%component_after_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i
1271 << "\n"
1272 << "%equal_" << i << " = " << opEqual << " %bool %component_" << i << " %component_after_"
1273 << i << "\n";
1274 if (i > 0)
1275 calcEqual << "%and_" << i << " = OpLogicalAnd %bool %equal_" << (i - 1) << " %equal_" << i << "\n";
1276 if (i == numComponents - 1)
1277 calcEqual << "%equal = OpCopyObject %bool %and_" << i << "\n";
1278 }
1279
1280 subs["CALC_EQUAL_STATEMENT"] = calcEqual.str();
1281 }
1282 else
1283 {
1284 // Vectors can be compared using a bool vector and OpAll.
1285 subs["CALC_EQUAL_STATEMENT"] = " %equal_vector = " + opEqual + " %v" + numComponentsStr +
1286 "bool %input_val_before %input_val_after\n";
1287 subs["CALC_EQUAL_STATEMENT"] += " %equal = OpAll %bool %equal_vector\n";
1288 }
1289 }
1290
1291 if (isArray)
1292 {
1293 // Arrays need an ArrayStride decoration.
1294 std::ostringstream interfaceDecorations;
1295 interfaceDecorations << "OpDecorate %v" << numComponentsStr << componentTypeName << " ArrayStride "
1296 << getElementSize(m_params.dataType, VectorType::SCALAR) << "\n";
1297 subs["INTERFACE_DECORATIONS"] = interfaceDecorations.str();
1298 }
1299
1300 const auto inputBlockDecls = getGLSLInputValDecl(m_params.dataType, m_params.vectorType);
1301
1302 std::ostringstream glslBindings;
1303 glslBindings << inputBlockDecls.first // Additional data types needed.
1304 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1305 << "layout(set = 0, binding = 1) buffer CalleeBlock { uint val; } calleeBuffer;\n"
1306 << "layout(set = 0, binding = 2) buffer OutputBlock { uint val; } outputBuffer;\n"
1307 << "layout(set = 0, binding = 3) buffer InputBlock { " << inputBlockDecls.second
1308 << " } inputBuffer;\n";
1309
1310 if (samplersNeeded(m_params.dataType))
1311 {
1312 glslBindings << "layout(set = 0, binding = 4) uniform utexture2D sampledTexture[2];\n"
1313 << "layout(set = 0, binding = 5) uniform sampler textureSampler[2];\n"
1314 << "layout(set = 0, binding = 6) uniform usampler2D combinedImageSampler[2];\n";
1315 }
1316 else if (storageImageNeeded(m_params.dataType))
1317 {
1318 glslBindings << "layout(set = 0, binding = 4, r32ui) uniform uimage2D storageImage;\n";
1319 }
1320
1321 const auto glslBindingsStr = glslBindings.str();
1322 const auto glslHeaderStr = "#version 460 core\n"
1323 "#extension GL_EXT_ray_tracing : require\n"
1324 "#extension GL_EXT_shader_explicit_arithmetic_types : require\n";
1325
1326 if (m_params.callType == CallType::TRACE_RAY)
1327 {
1328 subs["ENTRY_POINT"] = "RayGenerationKHR";
1329 subs["MAIN_INTERFACE_EXTRAS"] += " %hitValue";
1330 subs["INTERFACE_DECORATIONS"] += " OpDecorate %hitValue Location 0\n";
1331 subs["INTERFACE_TYPES_AND_VARIABLES"] =
1332 " %payload_ptr = OpTypePointer RayPayloadKHR %v3float\n"
1333 " %hitValue = OpVariable %payload_ptr RayPayloadKHR\n";
1334 subs["CALL_STATEMENTS"] =
1335 " %as_value = OpLoad %as_type %topLevelAS\n"
1336 " OpTraceRayKHR %as_value %uint_0 %uint_255 %zero_for_callable "
1337 "%zero_for_callable %zero_for_callable %origin_const %float_0 %direction_const %float_9 %hitValue\n";
1338
1339 const auto rgen = spvTemplate.specialize(subs);
1340 programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1341
1342 std::stringstream chit;
1343 chit << glslHeaderStr << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1344 << "hitAttributeEXT vec3 attribs;\n"
1345 << glslBindingsStr << "void main()\n"
1346 << "{\n"
1347 << " calleeBuffer.val = 1u;\n"
1348 << "}\n";
1349 programCollection.glslSources.add("chit")
1350 << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
1351 }
1352 else if (m_params.callType == CallType::EXECUTE_CALLABLE)
1353 {
1354 subs["ENTRY_POINT"] = "RayGenerationKHR";
1355 subs["MAIN_INTERFACE_EXTRAS"] += " %callableData";
1356 subs["INTERFACE_DECORATIONS"] += " OpDecorate %callableData Location 0\n";
1357 subs["INTERFACE_TYPES_AND_VARIABLES"] =
1358 " %callable_data_ptr = OpTypePointer CallableDataKHR %float\n"
1359 " %callableData = OpVariable %callable_data_ptr CallableDataKHR\n";
1360 subs["CALL_STATEMENTS"] =
1361 " OpExecuteCallableKHR %zero_for_callable %callableData\n";
1362
1363 const auto rgen = spvTemplate.specialize(subs);
1364 programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1365
1366 std::ostringstream call;
1367 call << glslHeaderStr << "layout(location = 0) callableDataInEXT float callableData;\n"
1368 << glslBindingsStr << "void main()\n"
1369 << "{\n"
1370 << " calleeBuffer.val = 1u;\n"
1371 << "}\n";
1372
1373 programCollection.glslSources.add("call")
1374 << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
1375 }
1376 else if (m_params.callType == CallType::REPORT_INTERSECTION)
1377 {
1378 subs["ENTRY_POINT"] = "IntersectionKHR";
1379 subs["MAIN_INTERFACE_EXTRAS"] += " %attribs";
1380 subs["INTERFACE_DECORATIONS"] += "";
1381 subs["INTERFACE_TYPES_AND_VARIABLES"] =
1382 " %hit_attribute_ptr = OpTypePointer HitAttributeKHR %v3float\n"
1383 " %attribs = OpVariable %hit_attribute_ptr HitAttributeKHR\n";
1384 subs["CALL_STATEMENTS"] =
1385 " %intersection_ret = OpReportIntersectionKHR %bool %float_1 %zero_for_callable\n";
1386
1387 const auto rint = spvTemplate.specialize(subs);
1388 programCollection.spirvAsmSources.add("rint") << rint << spvBuildOptions;
1389
1390 std::ostringstream rgen;
1391 rgen << glslHeaderStr << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
1392 << glslBindingsStr << "void main()\n"
1393 << "{\n"
1394 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
1395 "0);\n"
1396 << "}\n";
1397 programCollection.glslSources.add("rgen")
1398 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
1399
1400 std::stringstream ahit;
1401 ahit << glslHeaderStr << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1402 << "hitAttributeEXT vec3 attribs;\n"
1403 << glslBindingsStr << "void main()\n"
1404 << "{\n"
1405 << " calleeBuffer.val = 1u;\n"
1406 << "}\n";
1407 programCollection.glslSources.add("ahit")
1408 << glu::AnyHitSource(updateRayTracingGLSL(ahit.str())) << buildOptions;
1409 }
1410 else
1411 {
1412 DE_ASSERT(false);
1413 }
1414 }
1415
1416 using v2i32 = tcu::Vector<int32_t, 2>;
1417 using v3i32 = tcu::Vector<int32_t, 3>;
1418 using v4i32 = tcu::Vector<int32_t, 4>;
1419 using a5i32 = std::array<int32_t, 5>;
1420
1421 using v2u32 = tcu::Vector<uint32_t, 2>;
1422 using v3u32 = tcu::Vector<uint32_t, 3>;
1423 using v4u32 = tcu::Vector<uint32_t, 4>;
1424 using a5u32 = std::array<uint32_t, 5>;
1425
1426 using v2i64 = tcu::Vector<int64_t, 2>;
1427 using v3i64 = tcu::Vector<int64_t, 3>;
1428 using v4i64 = tcu::Vector<int64_t, 4>;
1429 using a5i64 = std::array<int64_t, 5>;
1430
1431 using v2u64 = tcu::Vector<uint64_t, 2>;
1432 using v3u64 = tcu::Vector<uint64_t, 3>;
1433 using v4u64 = tcu::Vector<uint64_t, 4>;
1434 using a5u64 = std::array<uint64_t, 5>;
1435
1436 using v2i16 = tcu::Vector<int16_t, 2>;
1437 using v3i16 = tcu::Vector<int16_t, 3>;
1438 using v4i16 = tcu::Vector<int16_t, 4>;
1439 using a5i16 = std::array<int16_t, 5>;
1440
1441 using v2u16 = tcu::Vector<uint16_t, 2>;
1442 using v3u16 = tcu::Vector<uint16_t, 3>;
1443 using v4u16 = tcu::Vector<uint16_t, 4>;
1444 using a5u16 = std::array<uint16_t, 5>;
1445
1446 using v2i8 = tcu::Vector<int8_t, 2>;
1447 using v3i8 = tcu::Vector<int8_t, 3>;
1448 using v4i8 = tcu::Vector<int8_t, 4>;
1449 using a5i8 = std::array<int8_t, 5>;
1450
1451 using v2u8 = tcu::Vector<uint8_t, 2>;
1452 using v3u8 = tcu::Vector<uint8_t, 3>;
1453 using v4u8 = tcu::Vector<uint8_t, 4>;
1454 using a5u8 = std::array<uint8_t, 5>;
1455
1456 using v2f32 = tcu::Vector<tcu::Float32, 2>;
1457 using v3f32 = tcu::Vector<tcu::Float32, 3>;
1458 using v4f32 = tcu::Vector<tcu::Float32, 4>;
1459 using a5f32 = std::array<tcu::Float32, 5>;
1460
1461 using v2f64 = tcu::Vector<tcu::Float64, 2>;
1462 using v3f64 = tcu::Vector<tcu::Float64, 3>;
1463 using v4f64 = tcu::Vector<tcu::Float64, 4>;
1464 using a5f64 = std::array<tcu::Float64, 5>;
1465
1466 using v2f16 = tcu::Vector<tcu::Float16, 2>;
1467 using v3f16 = tcu::Vector<tcu::Float16, 3>;
1468 using v4f16 = tcu::Vector<tcu::Float16, 4>;
1469 using a5f16 = std::array<tcu::Float16, 5>;
1470
1471 // Scalar types get filled with value 37, matching the value that will be subtracted in the shader.
1472 #define GEN_SCALAR_FILL(DATA_TYPE) \
1473 do \
1474 { \
1475 const auto inputBufferValue = static_cast<DATA_TYPE>(37.0); \
1476 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1477 } while (0)
1478
1479 // Vector types get filled with values that add up to 37, matching the value that will be subtracted in the shader.
1480 #define GEN_V2_FILL(DATA_TYPE) \
1481 do \
1482 { \
1483 DATA_TYPE inputBufferValue; \
1484 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(21.0); \
1485 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(16.0); \
1486 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1487 } while (0)
1488
1489 #define GEN_V3_FILL(DATA_TYPE) \
1490 do \
1491 { \
1492 DATA_TYPE inputBufferValue; \
1493 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(11.0); \
1494 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(19.0); \
1495 inputBufferValue.z() = static_cast<DATA_TYPE::Element>(7.0); \
1496 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1497 } while (0)
1498
1499 #define GEN_V4_FILL(DATA_TYPE) \
1500 do \
1501 { \
1502 DATA_TYPE inputBufferValue; \
1503 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(9.0); \
1504 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(11.0); \
1505 inputBufferValue.z() = static_cast<DATA_TYPE::Element>(3.0); \
1506 inputBufferValue.w() = static_cast<DATA_TYPE::Element>(14.0); \
1507 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1508 } while (0)
1509
1510 #define GEN_A5_FILL(DATA_TYPE) \
1511 do \
1512 { \
1513 DATA_TYPE inputBufferValue; \
1514 inputBufferValue[0] = static_cast<DATA_TYPE::value_type>(13.0); \
1515 inputBufferValue[1] = static_cast<DATA_TYPE::value_type>(6.0); \
1516 inputBufferValue[2] = static_cast<DATA_TYPE::value_type>(2.0); \
1517 inputBufferValue[3] = static_cast<DATA_TYPE::value_type>(5.0); \
1518 inputBufferValue[4] = static_cast<DATA_TYPE::value_type>(11.0); \
1519 deMemcpy(bufferPtr, inputBufferValue.data(), de::dataSize(inputBufferValue)); \
1520 } while (0)
1521
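// Fill the input buffer according to the requested data and vector type using the macros above. In every case the
// values are chosen so the shader can combine them back into 37.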
1522 void fillInputBuffer(DataType dataType, VectorType vectorType, void *bufferPtr)
1523 {
1524 if (vectorType == VectorType::SCALAR)
1525 {
1526 if (dataType == DataType::INT32)
1527 GEN_SCALAR_FILL(int32_t);
1528 else if (dataType == DataType::UINT32)
1529 GEN_SCALAR_FILL(uint32_t);
1530 else if (dataType == DataType::INT64)
1531 GEN_SCALAR_FILL(int64_t);
1532 else if (dataType == DataType::UINT64)
1533 GEN_SCALAR_FILL(uint64_t);
1534 else if (dataType == DataType::INT16)
1535 GEN_SCALAR_FILL(int16_t);
1536 else if (dataType == DataType::UINT16)
1537 GEN_SCALAR_FILL(uint16_t);
1538 else if (dataType == DataType::INT8)
1539 GEN_SCALAR_FILL(int8_t);
1540 else if (dataType == DataType::UINT8)
1541 GEN_SCALAR_FILL(uint8_t);
1542 else if (dataType == DataType::FLOAT32)
1543 GEN_SCALAR_FILL(tcu::Float32);
1544 else if (dataType == DataType::FLOAT64)
1545 GEN_SCALAR_FILL(tcu::Float64);
1546 else if (dataType == DataType::FLOAT16)
1547 GEN_SCALAR_FILL(tcu::Float16);
1548 else if (dataType == DataType::STRUCT)
1549 {
1550 InputStruct data = {12u, 25.0f};
1551 deMemcpy(bufferPtr, &data, sizeof(data));
1552 }
1553 else if (dataType == DataType::OP_NULL)
1554 GEN_SCALAR_FILL(uint32_t);
1555 else if (dataType == DataType::OP_UNDEF)
1556 GEN_SCALAR_FILL(uint32_t);
1557 else
1558 {
1559 DE_ASSERT(false);
1560 }
1561 }
1562 else if (vectorType == VectorType::V2)
1563 {
1564 if (dataType == DataType::INT32)
1565 GEN_V2_FILL(v2i32);
1566 else if (dataType == DataType::UINT32)
1567 GEN_V2_FILL(v2u32);
1568 else if (dataType == DataType::INT64)
1569 GEN_V2_FILL(v2i64);
1570 else if (dataType == DataType::UINT64)
1571 GEN_V2_FILL(v2u64);
1572 else if (dataType == DataType::INT16)
1573 GEN_V2_FILL(v2i16);
1574 else if (dataType == DataType::UINT16)
1575 GEN_V2_FILL(v2u16);
1576 else if (dataType == DataType::INT8)
1577 GEN_V2_FILL(v2i8);
1578 else if (dataType == DataType::UINT8)
1579 GEN_V2_FILL(v2u8);
1580 else if (dataType == DataType::FLOAT32)
1581 GEN_V2_FILL(v2f32);
1582 else if (dataType == DataType::FLOAT64)
1583 GEN_V2_FILL(v2f64);
1584 else if (dataType == DataType::FLOAT16)
1585 GEN_V2_FILL(v2f16);
1586 else
1587 {
1588 DE_ASSERT(false);
1589 }
1590 }
1591 else if (vectorType == VectorType::V3)
1592 {
1593 if (dataType == DataType::INT32)
1594 GEN_V3_FILL(v3i32);
1595 else if (dataType == DataType::UINT32)
1596 GEN_V3_FILL(v3u32);
1597 else if (dataType == DataType::INT64)
1598 GEN_V3_FILL(v3i64);
1599 else if (dataType == DataType::UINT64)
1600 GEN_V3_FILL(v3u64);
1601 else if (dataType == DataType::INT16)
1602 GEN_V3_FILL(v3i16);
1603 else if (dataType == DataType::UINT16)
1604 GEN_V3_FILL(v3u16);
1605 else if (dataType == DataType::INT8)
1606 GEN_V3_FILL(v3i8);
1607 else if (dataType == DataType::UINT8)
1608 GEN_V3_FILL(v3u8);
1609 else if (dataType == DataType::FLOAT32)
1610 GEN_V3_FILL(v3f32);
1611 else if (dataType == DataType::FLOAT64)
1612 GEN_V3_FILL(v3f64);
1613 else if (dataType == DataType::FLOAT16)
1614 GEN_V3_FILL(v3f16);
1615 else
1616 {
1617 DE_ASSERT(false);
1618 }
1619 }
1620 else if (vectorType == VectorType::V4)
1621 {
1622 if (dataType == DataType::INT32)
1623 GEN_V4_FILL(v4i32);
1624 else if (dataType == DataType::UINT32)
1625 GEN_V4_FILL(v4u32);
1626 else if (dataType == DataType::INT64)
1627 GEN_V4_FILL(v4i64);
1628 else if (dataType == DataType::UINT64)
1629 GEN_V4_FILL(v4u64);
1630 else if (dataType == DataType::INT16)
1631 GEN_V4_FILL(v4i16);
1632 else if (dataType == DataType::UINT16)
1633 GEN_V4_FILL(v4u16);
1634 else if (dataType == DataType::INT8)
1635 GEN_V4_FILL(v4i8);
1636 else if (dataType == DataType::UINT8)
1637 GEN_V4_FILL(v4u8);
1638 else if (dataType == DataType::FLOAT32)
1639 GEN_V4_FILL(v4f32);
1640 else if (dataType == DataType::FLOAT64)
1641 GEN_V4_FILL(v4f64);
1642 else if (dataType == DataType::FLOAT16)
1643 GEN_V4_FILL(v4f16);
1644 else
1645 {
1646 DE_ASSERT(false);
1647 }
1648 }
1649 else if (vectorType == VectorType::A5)
1650 {
1651 if (dataType == DataType::INT32)
1652 GEN_A5_FILL(a5i32);
1653 else if (dataType == DataType::UINT32)
1654 GEN_A5_FILL(a5u32);
1655 else if (dataType == DataType::INT64)
1656 GEN_A5_FILL(a5i64);
1657 else if (dataType == DataType::UINT64)
1658 GEN_A5_FILL(a5u64);
1659 else if (dataType == DataType::INT16)
1660 GEN_A5_FILL(a5i16);
1661 else if (dataType == DataType::UINT16)
1662 GEN_A5_FILL(a5u16);
1663 else if (dataType == DataType::INT8)
1664 GEN_A5_FILL(a5i8);
1665 else if (dataType == DataType::UINT8)
1666 GEN_A5_FILL(a5u8);
1667 else if (dataType == DataType::FLOAT32)
1668 GEN_A5_FILL(a5f32);
1669 else if (dataType == DataType::FLOAT64)
1670 GEN_A5_FILL(a5f64);
1671 else if (dataType == DataType::FLOAT16)
1672 GEN_A5_FILL(a5f16);
1673 else
1674 {
1675 DE_ASSERT(false);
1676 }
1677 }
1678 else
1679 {
1680 DE_ASSERT(false);
1681 }
1682 }
1683
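// Sets up the buffers, acceleration structures, descriptors and ray tracing pipeline for the selected call type,
// launches a single 1x1x1 trace and verifies that both the callee and output buffers end up containing 1.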
1684 tcu::TestStatus DataSpillTestInstance::iterate(void)
1685 {
1686 const auto &vki = m_context.getInstanceInterface();
1687 const auto physicalDevice = m_context.getPhysicalDevice();
1688 const auto &vkd = m_context.getDeviceInterface();
1689 const auto device = m_context.getDevice();
1690 const auto queue = m_context.getUniversalQueue();
1691 const auto familyIndex = m_context.getUniversalQueueFamilyIndex();
1692 auto &alloc = m_context.getDefaultAllocator();
1693 const auto shaderStages = getShaderStages(m_params.callType);
1694
1695 // Command buffer.
1696 const auto cmdPool = makeCommandPool(vkd, device, familyIndex);
1697 const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1698 const auto cmdBuffer = cmdBufferPtr.get();
1699
1700 beginCommandBuffer(vkd, cmdBuffer);
1701
1702 // Callee, input and output buffers.
1703 const auto calleeBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1704 const auto outputBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1705 const auto inputBufferSize = getElementSize(m_params.dataType, m_params.vectorType);
1706
1707 const auto calleeBufferInfo = makeBufferCreateInfo(calleeBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1708 const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1709 const auto inputBufferInfo = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1710
1711 BufferWithMemory calleeBuffer(vkd, device, alloc, calleeBufferInfo, MemoryRequirement::HostVisible);
1712 BufferWithMemory outputBuffer(vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
1713 BufferWithMemory inputBuffer(vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
1714
1715 // Fill buffers with values.
1716 auto &calleeBufferAlloc = calleeBuffer.getAllocation();
1717 auto *calleeBufferPtr = calleeBufferAlloc.getHostPtr();
1718 auto &outputBufferAlloc = outputBuffer.getAllocation();
1719 auto *outputBufferPtr = outputBufferAlloc.getHostPtr();
1720 auto &inputBufferAlloc = inputBuffer.getAllocation();
1721 auto *inputBufferPtr = inputBufferAlloc.getHostPtr();
1722
1723 deMemset(calleeBufferPtr, 0, static_cast<size_t>(calleeBufferSize));
1724 deMemset(outputBufferPtr, 0, static_cast<size_t>(outputBufferSize));
1725
1726 if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1727 {
1728 // The input buffer for these cases will be filled with zeros (sampling coordinates), and the input textures will contain the interesting input value.
1729 deMemset(inputBufferPtr, 0, static_cast<size_t>(inputBufferSize));
1730 }
1731 else
1732 {
1733 // We want to fill the input buffer with values that will be consistently used in the shader to obtain a result of zero.
1734 fillInputBuffer(m_params.dataType, m_params.vectorType, inputBufferPtr);
1735 }
1736
1737 flushAlloc(vkd, device, calleeBufferAlloc);
1738 flushAlloc(vkd, device, outputBufferAlloc);
1739 flushAlloc(vkd, device, inputBufferAlloc);
1740
1741 // Acceleration structures.
1742 de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
1743 de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
1744
1745 bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
1746 bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.callType),
1747 VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
1748 bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1749
1750 topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
1751 topLevelAccelerationStructure->setInstanceCount(1);
1752 topLevelAccelerationStructure->addInstance(
1753 de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
1754 topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1755
1756 // Get some ray tracing properties.
1757 uint32_t shaderGroupHandleSize = 0u;
1758 uint32_t shaderGroupBaseAlignment = 1u;
1759 {
1760 const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
1761 shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
1762 shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
1763 }
1764
1765 // Textures and samplers if needed.
1766 de::MovePtr<BufferWithMemory> textureData;
1767 std::vector<de::MovePtr<ImageWithMemory>> textures;
1768 std::vector<Move<VkImageView>> textureViews;
1769 std::vector<Move<VkSampler>> samplers;
1770
1771 if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1772 {
1773 // Create texture data with the expected contents.
1774 {
1775 const auto textureDataSize = static_cast<VkDeviceSize>(sizeof(uint32_t));
1776 const auto textureDataCreateInfo = makeBufferCreateInfo(textureDataSize, VK_BUFFER_USAGE_TRANSFER_SRC_BIT);
1777
1778 textureData = de::MovePtr<BufferWithMemory>(
1779 new BufferWithMemory(vkd, device, alloc, textureDataCreateInfo, MemoryRequirement::HostVisible));
1780 auto &textureDataAlloc = textureData->getAllocation();
1781 auto *textureDataPtr = textureDataAlloc.getHostPtr();
1782
1783 fillInputBuffer(DataType::UINT32, VectorType::SCALAR, textureDataPtr);
1784 flushAlloc(vkd, device, textureDataAlloc);
1785 }
1786
1787 // All images will be created from this template; the usage flags change depending on the case.
1788 VkImageCreateInfo imageCreateInfo = {
1789 VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
1790 nullptr, // const void* pNext;
1791 0u, // VkImageCreateFlags flags;
1792 VK_IMAGE_TYPE_2D, // VkImageType imageType;
1793 kImageFormat, // VkFormat format;
1794 kImageExtent, // VkExtent3D extent;
1795 1u, // uint32_t mipLevels;
1796 1u, // uint32_t arrayLayers;
1797 VK_SAMPLE_COUNT_1_BIT, // VkSampleCountFlagBits samples;
1798 VK_IMAGE_TILING_OPTIMAL, // VkImageTiling tiling;
1799 kSampledImageUsage, // VkImageUsageFlags usage;
1800 VK_SHARING_MODE_EXCLUSIVE, // VkSharingMode sharingMode;
1801 0u, // uint32_t queueFamilyIndexCount;
1802 nullptr, // const uint32_t* pQueueFamilyIndices;
1803 VK_IMAGE_LAYOUT_UNDEFINED, // VkImageLayout initialLayout;
1804 };
1805
1806 const auto imageSubresourceRange = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
1807 const auto imageSubresourceLayers = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
1808
1809 if (samplersNeeded(m_params.dataType))
1810 {
1811 // All samplers will be created like this.
1812 const VkSamplerCreateInfo samplerCreateInfo = {
1813 VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // VkStructureType sType;
1814 nullptr, // const void* pNext;
1815 0u, // VkSamplerCreateFlags flags;
1816 VK_FILTER_NEAREST, // VkFilter magFilter;
1817 VK_FILTER_NEAREST, // VkFilter minFilter;
1818 VK_SAMPLER_MIPMAP_MODE_NEAREST, // VkSamplerMipmapMode mipmapMode;
1819 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeU;
1820 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeV;
1821 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeW;
1822 0.0, // float mipLodBias;
1823 VK_FALSE, // VkBool32 anisotropyEnable;
1824 1.0f, // float maxAnisotropy;
1825 VK_FALSE, // VkBool32 compareEnable;
1826 VK_COMPARE_OP_ALWAYS, // VkCompareOp compareOp;
1827 0.0f, // float minLod;
1828 1.0f, // float maxLod;
1829 VK_BORDER_COLOR_INT_OPAQUE_BLACK, // VkBorderColor borderColor;
1830 VK_FALSE, // VkBool32 unnormalizedCoordinates;
1831 };
1832
1833 // Create textures and samplers.
1834 for (size_t i = 0; i < kNumImages; ++i)
1835 {
1836 textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1837 textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D,
1838 kImageFormat, imageSubresourceRange));
1839 }
1840
1841 for (size_t i = 0; i < kNumSamplers; ++i)
1842 samplers.emplace_back(createSampler(vkd, device, &samplerCreateInfo));
1843
1844 // Make sure texture data is available in the transfer stage.
1845 const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1846 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u,
1847 &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1848
1849 const auto bufferImageCopy = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1850
1851 // Fill textures with data and prepare them for the ray tracing pipeline stages.
1852 for (size_t i = 0; i < kNumImages; ++i)
1853 {
1854 const auto texturePreCopyBarrier = makeImageMemoryBarrier(
1855 0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1856 textures[i]->get(), imageSubresourceRange);
1857 const auto texturePostCopyBarrier = makeImageMemoryBarrier(
1858 VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1859 VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, textures[i]->get(), imageSubresourceRange);
1860
1861 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u,
1862 0u, nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1863 vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures[i]->get(),
1864 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1865 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
1866 VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u,
1867 &texturePostCopyBarrier);
1868 }
1869 }
1870 else if (storageImageNeeded(m_params.dataType))
1871 {
1872 // Image will be used for storage.
1873 imageCreateInfo.usage = kStorageImageUsage;
1874
1875 textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1876 textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D,
1877 kImageFormat, imageSubresourceRange));
1878
1879 // Make sure texture data is available in the transfer stage.
1880 const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1881 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u,
1882 &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1883
1884 const auto bufferImageCopy = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1885 const auto texturePreCopyBarrier = makeImageMemoryBarrier(
1886 0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
1887 textures.back()->get(), imageSubresourceRange);
1888 const auto texturePostCopyBarrier = makeImageMemoryBarrier(
1889 VK_ACCESS_TRANSFER_WRITE_BIT, (VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT),
1890 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, textures.back()->get(),
1891 imageSubresourceRange);
1892
1893 // Fill the texture with data and prepare it for the ray tracing pipeline stages.
1894 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 0u,
1895 nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1896 vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures.back()->get(),
1897 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1898 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
1899 VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u,
1900 &texturePostCopyBarrier);
1901 }
1902 else
1903 {
1904 DE_ASSERT(false);
1905 }
1906 }
1907
1908 // Descriptor set layout.
1909 DescriptorSetLayoutBuilder dslBuilder;
1910 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
1911 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Callee buffer.
1912 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Output buffer.
1913 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Input buffer.
1914 if (samplersNeeded(m_params.dataType))
1915 {
1916 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u, shaderStages, nullptr);
1917 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLER, 2u, shaderStages, nullptr);
1918 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u, shaderStages, nullptr);
1919 }
1920 else if (storageImageNeeded(m_params.dataType))
1921 {
1922 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u, shaderStages, nullptr);
1923 }
1924 const auto descriptorSetLayout = dslBuilder.build(vkd, device);
1925
1926 // Pipeline layout.
1927 const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
1928
1929 // Descriptor pool and set.
1930 DescriptorPoolBuilder poolBuilder;
1931 poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
1932 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u);
1933 if (samplersNeeded(m_params.dataType))
1934 {
1935 poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u);
1936 poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLER, 2u);
1937 poolBuilder.addType(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u);
1938 }
1939 else if (storageImageNeeded(m_params.dataType))
1940 {
1941 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u);
1942 }
1943 const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
1944 const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
1945
1946 // Update descriptor set.
1947 {
1948 const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo = {
1949 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
1950 nullptr,
1951 1u,
1952 topLevelAccelerationStructure.get()->getPtr(),
1953 };
1954
1955 DescriptorSetUpdateBuilder updateBuilder;
1956
1957 const auto ds = descriptorSet.get();
1958
1959 const auto calleeBufferDescriptorInfo = makeDescriptorBufferInfo(calleeBuffer.get(), 0ull, VK_WHOLE_SIZE);
1960 const auto outputBufferDescriptorInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1961 const auto inputBufferDescriptorInfo = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1962
1963 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u),
1964 VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
1965 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u),
1966 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &calleeBufferDescriptorInfo);
1967 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(2u),
1968 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
1969 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(3u),
1970 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
1971
1972 if (samplersNeeded(m_params.dataType))
1973 {
1974 // Update textures, samplers and combined image samplers.
1975 std::vector<VkDescriptorImageInfo> textureDescInfos;
1976 std::vector<VkDescriptorImageInfo> textureSamplerInfos;
1977 std::vector<VkDescriptorImageInfo> combinedSamplerInfos;
1978
1979 for (size_t i = 0; i < kNumAloneImages; ++i)
1980 textureDescInfos.push_back(
1981 makeDescriptorImageInfo(DE_NULL, textureViews[i].get(), VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1982 for (size_t i = 0; i < kNumAloneSamplers; ++i)
1983 textureSamplerInfos.push_back(
1984 makeDescriptorImageInfo(samplers[i].get(), DE_NULL, VK_IMAGE_LAYOUT_UNDEFINED));
1985
1986 for (size_t i = 0; i < kNumCombined; ++i)
1987 combinedSamplerInfos.push_back(makeDescriptorImageInfo(samplers[i + kNumAloneSamplers].get(),
1988 textureViews[i + kNumAloneImages].get(),
1989 VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1990
1991 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(4u),
1992 VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, kNumAloneImages, textureDescInfos.data());
1993 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(5u), VK_DESCRIPTOR_TYPE_SAMPLER,
1994 kNumAloneSamplers, textureSamplerInfos.data());
1995 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(6u),
1996 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, kNumCombined,
1997 combinedSamplerInfos.data());
1998 }
1999 else if (storageImageNeeded(m_params.dataType))
2000 {
2001 const auto storageImageDescriptorInfo =
2002 makeDescriptorImageInfo(DE_NULL, textureViews.back().get(), VK_IMAGE_LAYOUT_GENERAL);
2003 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(4u),
2004 VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &storageImageDescriptorInfo);
2005 }
2006
2007 updateBuilder.update(vkd, device);
2008 }
2009
2010 // Create raytracing pipeline and shader binding tables.
2011 Move<VkPipeline> pipeline;
2012
2013 de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
2014 de::MovePtr<BufferWithMemory> missShaderBindingTable;
2015 de::MovePtr<BufferWithMemory> hitShaderBindingTable;
2016 de::MovePtr<BufferWithMemory> callableShaderBindingTable;
2017
2018 VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2019 VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2020 VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2021 VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
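// Regions for tables the selected call type does not use are left empty (null address, zero size and stride).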
2022
2023 {
2024 const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
2025 const auto callType = m_params.callType;
2026
2027 // Every case uses a ray generation shader.
2028 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
2029 createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0),
2030 0);
2031
2032 if (callType == CallType::TRACE_RAY)
2033 {
2034 rayTracingPipeline->addShader(
2035 VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2036 createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
2037 }
2038 else if (callType == CallType::EXECUTE_CALLABLE)
2039 {
2040 rayTracingPipeline->addShader(
2041 VK_SHADER_STAGE_CALLABLE_BIT_KHR,
2042 createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
2043 }
2044 else if (callType == CallType::REPORT_INTERSECTION)
2045 {
2046 rayTracingPipeline->addShader(
2047 VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
2048 createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
2049 rayTracingPipeline->addShader(
2050 VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
2051 createShaderModule(vkd, device, m_context.getBinaryCollection().get("ahit"), 0), 1);
2052 }
2053 else
2054 {
2055 DE_ASSERT(false);
2056 }
2057
2058 pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
2059
2060 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2061 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
2062 raygenShaderBindingTableRegion =
2063 makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0),
2064 shaderGroupHandleSize, shaderGroupHandleSize);
2065
2066 if (callType == CallType::EXECUTE_CALLABLE)
2067 {
2068 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2069 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2070 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2071 getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
2072 shaderGroupHandleSize);
2073 }
2074 else if (callType == CallType::TRACE_RAY || callType == CallType::REPORT_INTERSECTION)
2075 {
2076 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2077 vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
2078 hitShaderBindingTableRegion =
2079 makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
2080 shaderGroupHandleSize, shaderGroupHandleSize);
2081 }
2082 else
2083 {
2084 DE_ASSERT(false);
2085 }
2086 }
2087
2088 // Use ray tracing pipeline.
2089 vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
2090 vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
2091 &descriptorSet.get(), 0u, nullptr);
2092 vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
2093 &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);
2094
2095 // Synchronize output and callee buffers.
2096 const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
2097 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
2098 &memBarrier, 0u, nullptr, 0u, nullptr);
2099
2100 endCommandBuffer(vkd, cmdBuffer);
2101 submitCommandsAndWait(vkd, device, queue, cmdBuffer);
2102
2103 // Verify output and callee buffers.
2104 invalidateAlloc(vkd, device, outputBufferAlloc);
2105 invalidateAlloc(vkd, device, calleeBufferAlloc);
2106
2107 std::map<std::string, void *> bufferPtrs;
2108 bufferPtrs["output"] = outputBufferPtr;
2109 bufferPtrs["callee"] = calleeBufferPtr;
2110
2111 for (const auto &ptr : bufferPtrs)
2112 {
2113 const auto &bufferName = ptr.first;
2114 const auto &bufferPtr = ptr.second;
2115
2116 uint32_t outputVal;
2117 deMemcpy(&outputVal, bufferPtr, sizeof(outputVal));
2118
2119 if (outputVal != 1u)
2120 return tcu::TestStatus::fail("Unexpected value found in " + bufferName +
2121 " buffer: " + de::toString(outputVal));
2122 }
2123
2124 return tcu::TestStatus::pass("Pass");
2125 }
2126
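// The type of pipeline interface variable that will be tested.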
2127 enum class InterfaceType
2128 {
2129 RAY_PAYLOAD = 0,
2130 CALLABLE_DATA,
2131 HIT_ATTRIBUTES,
2132 SHADER_RECORD_BUFFER_RGEN,
2133 SHADER_RECORD_BUFFER_CALL,
2134 SHADER_RECORD_BUFFER_MISS,
2135 SHADER_RECORD_BUFFER_HIT,
2136 };
2137
2138 // Separate class to ease testing pipeline interface variables.
2139 class DataSpillPipelineInterfaceTestCase : public vkt::TestCase
2140 {
2141 public:
2142 struct TestParams
2143 {
2144 InterfaceType interfaceType;
2145 };
2146
2147 DataSpillPipelineInterfaceTestCase(tcu::TestContext &testCtx, const std::string &name,
2148 const TestParams &testParams);
2149 virtual ~DataSpillPipelineInterfaceTestCase(void)
2150 {
2151 }
2152
2153 virtual void initPrograms(vk::SourceCollections &programCollection) const;
2154 virtual TestInstance *createInstance(Context &context) const;
2155 virtual void checkSupport(Context &context) const;
2156
2157 private:
2158 TestParams m_params;
2159 };
2160
2161 class DataSpillPipelineInterfaceTestInstance : public vkt::TestInstance
2162 {
2163 public:
2164 using TestParams = DataSpillPipelineInterfaceTestCase::TestParams;
2165
2166 DataSpillPipelineInterfaceTestInstance(Context &context, const TestParams &testParams);
2167 ~DataSpillPipelineInterfaceTestInstance(void)
2168 {
2169 }
2170
2171 tcu::TestStatus iterate(void);
2172
2173 private:
2174 TestParams m_params;
2175 };
2176
2177 DataSpillPipelineInterfaceTestCase::DataSpillPipelineInterfaceTestCase(tcu::TestContext &testCtx,
2178 const std::string &name,
2179 const TestParams &testParams)
2180 : vkt::TestCase(testCtx, name)
2181 , m_params(testParams)
2182 {
2183 }
2184
2185 TestInstance *DataSpillPipelineInterfaceTestCase::createInstance(Context &context) const
2186 {
2187 return new DataSpillPipelineInterfaceTestInstance(context, m_params);
2188 }
2189
2190 DataSpillPipelineInterfaceTestInstance::DataSpillPipelineInterfaceTestInstance(Context &context,
2191 const TestParams &testParams)
2192 : vkt::TestInstance(context)
2193 , m_params(testParams)
2194 {
2195 }
2196
2197 void DataSpillPipelineInterfaceTestCase::checkSupport(Context &context) const
2198 {
2199 commonCheckSupport(context);
2200 }
2201
2202 void DataSpillPipelineInterfaceTestCase::initPrograms(vk::SourceCollections &programCollection) const
2203 {
2204 const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
2205
2206 const std::string glslHeader = "#version 460 core\n"
2207 "#extension GL_EXT_ray_tracing : require\n";
2208
2209 const std::string glslBindings = "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2210 "layout(set = 0, binding = 1) buffer StorageBlock { uint val[" +
2211 std::to_string(kNumStorageValues) + "]; } storageBuffer;\n";
2212
2213 if (m_params.interfaceType == InterfaceType::RAY_PAYLOAD)
2214 {
2215 // The closest hit shader will store 100 in the second array position.
2216 // The ray gen shader will store 103 in the first array position using the hitValue after the traceRayEXT() call.
2217
2218 std::ostringstream rgen;
2219 rgen << glslHeader << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2220 << glslBindings << "void main()\n"
2221 << "{\n"
2222 << " hitValue = vec3(10.0, 30.0, 60.0);\n"
2223 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2224 "0);\n"
2225 << " storageBuffer.val[0] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2226 << "}\n";
2227 programCollection.glslSources.add("rgen")
2228 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2229
2230 std::stringstream chit;
2231 chit << glslHeader << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2232 << "hitAttributeEXT vec3 attribs;\n"
2233 << glslBindings << "void main()\n"
2234 << "{\n"
2235 << " storageBuffer.val[1] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2236 << " hitValue = vec3(hitValue.x + 1.0, hitValue.y + 1.0, hitValue.z + 1.0);\n"
2237 << "}\n";
2238 programCollection.glslSources.add("chit")
2239 << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2240 }
2241 else if (m_params.interfaceType == InterfaceType::CALLABLE_DATA)
2242 {
2243 // The callable shader will store 100 in the second array position.
2244 // The ray gen shader will store 200 in the first array position using the callable data after the executeCallableEXT() call.
2245
2246 std::ostringstream rgen;
2247 rgen << glslHeader << "layout(location = 0) callableDataEXT float callableData;\n"
2248 << glslBindings << "void main()\n"
2249 << "{\n"
2250 << " callableData = 100.0;\n"
2251 << " executeCallableEXT(0, 0);\n"
2252 << " storageBuffer.val[0] = uint(callableData);\n"
2253 << "}\n";
2254 programCollection.glslSources.add("rgen")
2255 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2256
2257 std::ostringstream call;
2258 call << glslHeader << "layout(location = 0) callableDataInEXT float callableData;\n"
2259 << glslBindings << "void main()\n"
2260 << "{\n"
2261 << " storageBuffer.val[1] = uint(callableData);\n"
2262 << " callableData = callableData * 2.0;\n"
2263 << "}\n";
2264
2265 programCollection.glslSources.add("call")
2266 << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2267 }
2268 else if (m_params.interfaceType == InterfaceType::HIT_ATTRIBUTES)
2269 {
2270 // The ray gen shader will store value 300 in the first storage buffer position.
2271 // The intersection shader will store value 315 in the second storage buffer position.
2272 // The closest hit shader will store value 330 in the third storage buffer position using the hit attributes.
2273
2274 std::ostringstream rgen;
2275 rgen << glslHeader << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2276 << glslBindings << "void main()\n"
2277 << "{\n"
2278 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2279 "0);\n"
2280 << " storageBuffer.val[0] = 300u;\n"
2281 << "}\n";
2282 programCollection.glslSources.add("rgen")
2283 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2284
2285 std::stringstream rint;
2286 rint << glslHeader << "hitAttributeEXT vec3 attribs;\n"
2287 << glslBindings << "void main()\n"
2288 << "{\n"
2289 << " attribs = vec3(140.0, 160.0, 30.0);\n"
2290 << " storageBuffer.val[1] = 315u;\n"
2291 << " reportIntersectionEXT(1.0f, 0);\n"
2292 << "}\n";
2293
2294 programCollection.glslSources.add("rint")
2295 << glu::IntersectionSource(updateRayTracingGLSL(rint.str())) << buildOptions;
2296
2297 std::stringstream chit;
2298 chit << glslHeader << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2299 << "hitAttributeEXT vec3 attribs;\n"
2300 << glslBindings << "void main()\n"
2301 << "{\n"
2302 << " storageBuffer.val[2] = uint(attribs.x + attribs.y + attribs.z);\n"
2303 << "}\n";
2304 programCollection.glslSources.add("chit")
2305 << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2306 }
2307 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2308 {
2309 // The ray gen shader will have a uvec4 in the shader record buffer with contents 400, 401, 402, 403.
2310 // The shader will call a callable shader indicating a position in that vec4 (0, 1, 2, 3). For example, let's use position 1.
2311 // The callable shader will return the indicated position+1 modulo 4, so it will return 2 in our case.
2312 // *After* returning from the callable shader, the raygen shader will use that reply to access position 2 and write a 402 in the first output buffer position.
2313 // The callable shader will store 450 in the second output buffer position.
2314
2315 std::ostringstream rgen;
2316 rgen << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2317 << " uvec4 info;\n"
2318 << "};\n"
2319 << "layout(location = 0) callableDataEXT uint callableData;\n"
2320 << glslBindings << "void main()\n"
2321 << "{\n"
2322 << " callableData = 1u;"
2323 << " executeCallableEXT(0, 0);\n"
2324 << " if (callableData == 0u) storageBuffer.val[0] = info.x;\n"
2325 << " else if (callableData == 1u) storageBuffer.val[0] = info.y;\n"
2326 << " else if (callableData == 2u) storageBuffer.val[0] = info.z;\n"
2327 << " else if (callableData == 3u) storageBuffer.val[0] = info.w;\n"
2328 << "}\n";
2329 programCollection.glslSources.add("rgen")
2330 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2331
2332 std::ostringstream call;
2333 call << glslHeader << "layout(location = 0) callableDataInEXT uint callableData;\n"
2334 << glslBindings << "void main()\n"
2335 << "{\n"
2336 << " storageBuffer.val[1] = 450u;\n"
2337 << " callableData = (callableData + 1u) % 4u;\n"
2338 << "}\n";
2339
2340 programCollection.glslSources.add("call")
2341 << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2342 }
2343 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2344 {
2345 // Similar to the previous case, with a twist:
2346 // * rgen passes the vector position.
2347 // * call increases that by one.
2348 // * subcall increases again and does the modulo operation, also writing 450 in the third output buffer value.
2349 // * call is the one accessing the vector at the returned position, writing 403 in this case to the second output buffer value.
2350 // * call passes this value back doubled to rgen, which writes it to the first output buffer value (806).
2351
2352 std::ostringstream rgen;
2353 rgen << glslHeader << "layout(location = 0) callableDataEXT uint callableData;\n"
2354 << glslBindings << "void main()\n"
2355 << "{\n"
2356 << " callableData = 1u;\n"
2357 << " executeCallableEXT(0, 0);\n"
2358 << " storageBuffer.val[0] = callableData;\n"
2359 << "}\n";
2360 programCollection.glslSources.add("rgen")
2361 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2362
2363 std::ostringstream call;
2364 call << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2365 << " uvec4 info;\n"
2366 << "};\n"
2367 << "layout(location = 0) callableDataInEXT uint callableDataIn;\n"
2368 << "layout(location = 1) callableDataEXT uint callableDataOut;\n"
2369 << glslBindings << "void main()\n"
2370 << "{\n"
2371 << " callableDataOut = callableDataIn + 1u;\n"
2372 << " executeCallableEXT(1, 1);\n"
2373 << " uint outputBufferValue = 777u;\n"
2374 << " if (callableDataOut == 0u) outputBufferValue = info.x;\n"
2375 << " else if (callableDataOut == 1u) outputBufferValue = info.y;\n"
2376 << " else if (callableDataOut == 2u) outputBufferValue = info.z;\n"
2377 << " else if (callableDataOut == 3u) outputBufferValue = info.w;\n"
2378 << " storageBuffer.val[1] = outputBufferValue;\n"
2379 << " callableDataIn = outputBufferValue * 2u;\n"
2380 << "}\n";
2381
2382 programCollection.glslSources.add("call")
2383 << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2384
2385 std::ostringstream subcall;
2386 subcall << glslHeader << "layout(location = 1) callableDataInEXT uint callableData;\n"
2387 << glslBindings << "void main()\n"
2388 << "{\n"
2389 << " callableData = (callableData + 1u) % 4u;\n"
2390 << " storageBuffer.val[2] = 450u;\n"
2391 << "}\n";
2392
2393 programCollection.glslSources.add("subcall")
2394 << glu::CallableSource(updateRayTracingGLSL(subcall.str())) << buildOptions;
2395 }
2396 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS ||
2397 m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2398 {
2399 // Similar to the previous one, but the intermediate call shader has been replaced with a miss or closest hit shader.
2400 // The rgen shader will communicate with the miss/chit shader using the ray payload instead of the callable data.
2401 // Also, the initial position will be 2, so after the two increments it wraps around to 0 and info.x is selected.
// The stored values change accordingly: the callable writes 490 instead of 450 and the payload is multiplied by 3 instead of 2.
2402
2403 std::ostringstream rgen;
2404 rgen << glslHeader << "layout(location = 0) rayPayloadEXT uint rayPayload;\n"
2405 << glslBindings << "void main()\n"
2406 << "{\n"
2407 << " rayPayload = 2u;\n"
2408 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, "
2409 "0);\n"
2410 << " storageBuffer.val[0] = rayPayload;\n"
2411 << "}\n";
2412 programCollection.glslSources.add("rgen")
2413 << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2414
2415 std::ostringstream chitOrMiss;
2416 chitOrMiss << glslHeader << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2417 << " uvec4 info;\n"
2418 << "};\n"
2419 << "layout(location = 0) rayPayloadInEXT uint rayPayload;\n"
2420 << "layout(location = 0) callableDataEXT uint callableData;\n"
2421 << glslBindings << "void main()\n"
2422 << "{\n"
2423 << " callableData = rayPayload + 1u;\n"
2424 << " executeCallableEXT(0, 0);\n"
2425 << " uint outputBufferValue = 777u;\n"
2426 << " if (callableData == 0u) outputBufferValue = info.x;\n"
2427 << " else if (callableData == 1u) outputBufferValue = info.y;\n"
2428 << " else if (callableData == 2u) outputBufferValue = info.z;\n"
2429 << " else if (callableData == 3u) outputBufferValue = info.w;\n"
2430 << " storageBuffer.val[1] = outputBufferValue;\n"
2431 << " rayPayload = outputBufferValue * 3u;\n"
2432 << "}\n";
2433
2434 if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
2435 programCollection.glslSources.add("miss")
2436 << glu::MissSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2437 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2438 programCollection.glslSources.add("chit")
2439 << glu::ClosestHitSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2440 else
2441 DE_ASSERT(false);
2442
2443 std::ostringstream call;
2444 call << glslHeader << "layout(location = 0) callableDataInEXT uint callableData;\n"
2445 << glslBindings << "void main()\n"
2446 << "{\n"
2447 << " storageBuffer.val[2] = 490u;\n"
2448 << " callableData = (callableData + 1u) % 4u;\n"
2449 << "}\n";
2450
2451 programCollection.glslSources.add("call")
2452 << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2453 }
2454 else
2455 {
2456 DE_ASSERT(false);
2457 }
2458 }
2459
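// Shader stages involved in each interface type.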
2460 VkShaderStageFlags getShaderStages(InterfaceType type_)
2461 {
2462 VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
2463
2464 switch (type_)
2465 {
2466 case InterfaceType::HIT_ATTRIBUTES:
2467 flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2468 // fallthrough.
2469 case InterfaceType::RAY_PAYLOAD:
2470 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2471 break;
2472 case InterfaceType::CALLABLE_DATA:
2473 case InterfaceType::SHADER_RECORD_BUFFER_RGEN:
2474 case InterfaceType::SHADER_RECORD_BUFFER_CALL:
2475 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2476 break;
2477 case InterfaceType::SHADER_RECORD_BUFFER_MISS:
2478 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2479 flags |= VK_SHADER_STAGE_MISS_BIT_KHR;
2480 break;
2481 case InterfaceType::SHADER_RECORD_BUFFER_HIT:
2482 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2483 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2484 break;
2485 default:
2486 DE_ASSERT(false);
2487 break;
2488 }
2489
2490 return flags;
2491 }
2492
2493 // Proper stage for generating default geometry.
2494 VkShaderStageFlagBits getShaderStageForGeometry(InterfaceType type_)
2495 {
2496 VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
2497
2498 switch (type_)
2499 {
2500 case InterfaceType::HIT_ATTRIBUTES:
2501 bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2502 break;
2503 case InterfaceType::RAY_PAYLOAD:
2504 bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2505 break;
2506 case InterfaceType::CALLABLE_DATA:
2507 bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2508 break;
2509 case InterfaceType::SHADER_RECORD_BUFFER_RGEN:
2510 bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2511 break;
2512 case InterfaceType::SHADER_RECORD_BUFFER_CALL:
2513 bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2514 break;
2515 case InterfaceType::SHADER_RECORD_BUFFER_MISS:
2516 bits = VK_SHADER_STAGE_MISS_BIT_KHR;
2517 break;
2518 case InterfaceType::SHADER_RECORD_BUFFER_HIT:
2519 bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2520 break;
2521 default:
2522 DE_ASSERT(false);
2523 break;
2524 }
2525
2526 DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
2527 return bits;
2528 }
2529
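// Creates a shader binding table whose entries hold the shader group handle followed by a uvec4 shader record, and
// writes the record data (400, 401, 402, 403) right after the first handle. The entry stride is the handle size plus
// the record size, rounded up to a multiple of the handle size.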
2530 void createSBTWithShaderRecord(const DeviceInterface &vkd, VkDevice device, vk::Allocator &alloc, VkPipeline pipeline,
2531 RayTracingPipeline *rayTracingPipeline, uint32_t shaderGroupHandleSize,
2532 uint32_t shaderGroupBaseAlignment, uint32_t firstGroup, uint32_t groupCount,
2533 de::MovePtr<BufferWithMemory> &shaderBindingTable,
2534 VkStridedDeviceAddressRegionKHR &shaderBindingTableRegion)
2535 {
2536 const auto alignedSize = de::roundUp(shaderGroupHandleSize + kShaderRecordSize, shaderGroupHandleSize);
2537 shaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2538 vkd, device, pipeline, alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, firstGroup, groupCount, 0u, 0u,
2539 MemoryRequirement::Any, 0u, 0u, kShaderRecordSize);
2540 shaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
2541 getBufferDeviceAddress(vkd, device, shaderBindingTable->get(), 0), alignedSize, groupCount * alignedSize);
2542
2543 // Fill shader record buffer data.
2544 // Note we will only fill the first shader record after the handle.
2545 const tcu::UVec4 shaderRecordData(400u, 401u, 402u, 403u);
2546 auto &sbtAlloc = shaderBindingTable->getAllocation();
2547 auto *dataPtr = reinterpret_cast<uint8_t *>(sbtAlloc.getHostPtr()) + shaderGroupHandleSize;
2548
2549 DE_STATIC_ASSERT(sizeof(shaderRecordData) == static_cast<size_t>(kShaderRecordSize));
2550 deMemcpy(dataPtr, &shaderRecordData, sizeof(shaderRecordData));
2551 }
2552
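// Sets up the storage buffer, acceleration structures, descriptors and ray tracing pipeline for the selected
// interface type and launches a single 1x1x1 trace; each interface type writes its own expected values into the
// storage buffer.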
2553 tcu::TestStatus DataSpillPipelineInterfaceTestInstance::iterate(void)
2554 {
2555 const auto &vki = m_context.getInstanceInterface();
2556 const auto physicalDevice = m_context.getPhysicalDevice();
2557 const auto &vkd = m_context.getDeviceInterface();
2558 const auto device = m_context.getDevice();
2559 const auto queue = m_context.getUniversalQueue();
2560 const auto familyIndex = m_context.getUniversalQueueFamilyIndex();
2561 auto &alloc = m_context.getDefaultAllocator();
2562 const auto shaderStages = getShaderStages(m_params.interfaceType);
2563
2564 // Command buffer.
2565 const auto cmdPool = makeCommandPool(vkd, device, familyIndex);
2566 const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
2567 const auto cmdBuffer = cmdBufferPtr.get();
2568
2569 beginCommandBuffer(vkd, cmdBuffer);
2570
2571 // Storage buffer.
2572 std::array<uint32_t, kNumStorageValues> storageBufferData;
2573 const auto storageBufferSize = de::dataSize(storageBufferData);
2574 const auto storagebufferInfo = makeBufferCreateInfo(storageBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
2575 BufferWithMemory storageBuffer(vkd, device, alloc, storagebufferInfo, MemoryRequirement::HostVisible);
2576
2577 // Zero-out buffer.
2578 auto &storageBufferAlloc = storageBuffer.getAllocation();
2579 auto *storageBufferPtr = storageBufferAlloc.getHostPtr();
2580 deMemset(storageBufferPtr, 0, storageBufferSize);
2581 flushAlloc(vkd, device, storageBufferAlloc);
2582
2583 // Acceleration structures.
2584 de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
2585 de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
2586
2587 bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
2588 bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.interfaceType),
2589 VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
2590 bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2591
2592 topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
2593 topLevelAccelerationStructure->setInstanceCount(1);
2594 topLevelAccelerationStructure->addInstance(
2595 de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
2596 topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2597
2598 // Get some ray tracing properties.
2599 uint32_t shaderGroupHandleSize = 0u;
2600 uint32_t shaderGroupBaseAlignment = 1u;
2601 {
2602 const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
2603 shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
2604 shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
2605 }
2606
2607 // Descriptor set layout.
2608 DescriptorSetLayoutBuilder dslBuilder;
2609 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
2610 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Storage buffer.
2611 const auto descriptorSetLayout = dslBuilder.build(vkd, device);
2612
2613 // Pipeline layout.
2614 const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
2615
2616 // Descriptor pool and set.
2617 DescriptorPoolBuilder poolBuilder;
2618 poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
2619 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
2620 const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
2621 const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
2622
2623 // Update descriptor set.
2624 {
2625 const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo = {
2626 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
2627 nullptr,
2628 1u,
2629 topLevelAccelerationStructure.get()->getPtr(),
2630 };
2631
2632 const auto ds = descriptorSet.get();
2633 const auto storageBufferDescriptorInfo = makeDescriptorBufferInfo(storageBuffer.get(), 0ull, VK_WHOLE_SIZE);
2634
2635 DescriptorSetUpdateBuilder updateBuilder;
2636 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u),
2637 VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
2638 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u),
2639 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferDescriptorInfo);
2640 updateBuilder.update(vkd, device);
2641 }
2642
2643 // Create raytracing pipeline and shader binding tables.
2644 const auto interfaceType = m_params.interfaceType;
2645 Move<VkPipeline> pipeline;
2646
2647 de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
2648 de::MovePtr<BufferWithMemory> missShaderBindingTable;
2649 de::MovePtr<BufferWithMemory> hitShaderBindingTable;
2650 de::MovePtr<BufferWithMemory> callableShaderBindingTable;
2651
2652 VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2653 VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2654 VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
2655 VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
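    // All regions start out empty; only the tables needed by the interface type under test are created and filled below.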

    {
        const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();

        // Every case uses a ray generation shader.
        rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
                                      createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0),
                                      0);

        if (interfaceType == InterfaceType::RAY_PAYLOAD)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
        }
        else if (interfaceType == InterfaceType::CALLABLE_DATA ||
                 interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
        }
        else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("subcall"), 0), 2);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_MISS_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0), 1);
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
        {
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
            rayTracingPipeline->addShader(
                VK_SHADER_STAGE_CALLABLE_BIT_KHR,
                createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
        }
        else
        {
            DE_ASSERT(false);
        }
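        // Group 0 always holds the ray generation shader; groups 1 and 2 (when present) hold the other stages and are
        // referenced by index when building the shader binding tables below.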

        pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());

        if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 0u, 1u, raygenShaderBindingTable,
                                      raygenShaderBindingTableRegion);
        }
        else
        {
            raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
                vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
            raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
                getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize,
                shaderGroupHandleSize);
        }

        if (interfaceType == InterfaceType::CALLABLE_DATA || interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
                vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
                getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
                shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::RAY_PAYLOAD || interfaceType == InterfaceType::HIT_ATTRIBUTES)
        {
            hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
                vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
            hitShaderBindingTableRegion =
                makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0),
                                                  shaderGroupHandleSize, shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
        {
            createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 2u,
                                      callableShaderBindingTable, callableShaderBindingTableRegion);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
        {
            createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 1u, missShaderBindingTable,
                                      missShaderBindingTableRegion);

            // Callable shader table.
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
                vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
                getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
                shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
        {
            createSBTWithShaderRecord(vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(),
                                      shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, 1u, hitShaderBindingTable,
                                      hitShaderBindingTableRegion);

            // Callable shader table.
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(
                vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(
                getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize,
                shaderGroupHandleSize);
        }
        else
        {
            DE_ASSERT(false);
        }
    }

    // Use ray tracing pipeline.
    vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
    vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u,
                              &descriptorSet.get(), 0u, nullptr);
    vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion,
                        &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);
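    // A single 1x1x1 dispatch is enough: the shaders write their results to the storage buffer, which is verified below.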

    // Synchronize the storage buffer with the host before reading it back.
    const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
    vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u,
                           &memBarrier, 0u, nullptr, 0u, nullptr);

    endCommandBuffer(vkd, cmdBuffer);
    submitCommandsAndWait(vkd, device, queue, cmdBuffer);

    // Verify storage buffer.
    invalidateAlloc(vkd, device, storageBufferAlloc);
    deMemcpy(storageBufferData.data(), storageBufferPtr, storageBufferSize);

    // These values must match what the shaders store.
    std::vector<uint32_t> expectedData;
    if (interfaceType == InterfaceType::RAY_PAYLOAD)
    {
        expectedData.push_back(103u);
        expectedData.push_back(100u);
    }
    else if (interfaceType == InterfaceType::CALLABLE_DATA)
    {
        expectedData.push_back(200u);
        expectedData.push_back(100u);
    }
    else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
    {
        expectedData.push_back(300u);
        expectedData.push_back(315u);
        expectedData.push_back(330u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
    {
        expectedData.push_back(402u);
        expectedData.push_back(450u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
    {
        expectedData.push_back(806u);
        expectedData.push_back(403u);
        expectedData.push_back(450u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS ||
             interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
    {
        expectedData.push_back(1200u);
        expectedData.push_back(400u);
        expectedData.push_back(490u);
    }
    else
    {
        DE_ASSERT(false);
    }

    size_t pos;
    for (pos = 0u; pos < expectedData.size(); ++pos)
    {
        const auto &stored = storageBufferData.at(pos);
        const auto &expected = expectedData.at(pos);
        if (stored != expected)
        {
            std::ostringstream msg;
            msg << "Unexpected output value found at position " << pos << " (expected " << expected << " but got "
                << stored << ")";
            return tcu::TestStatus::fail(msg.str());
        }
    }

    // Expect zeros in unused positions, as filled on the host.
    for (; pos < storageBufferData.size(); ++pos)
    {
        const auto &stored = storageBufferData.at(pos);
        if (stored != 0u)
        {
            std::ostringstream msg;
            msg << "Unexpected output value found at position " << pos << " (expected 0 but got " << stored << ")";
            return tcu::TestStatus::fail(msg.str());
        }
    }

    return tcu::TestStatus::pass("Pass");
}

} // anonymous namespace

tcu::TestCaseGroup *createDataSpillTests(tcu::TestContext &testCtx)
{
    // Ray tracing tests for data spilling and unspilling around shader calls.
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "data_spill"));

    struct
    {
        CallType callType;
        const char *name;
    } callTypes[] = {
        {CallType::EXECUTE_CALLABLE, "execute_callable"},
        {CallType::TRACE_RAY, "trace_ray"},
        {CallType::REPORT_INTERSECTION, "report_intersection"},
    };

    struct
    {
        DataType dataType;
        const char *name;
    } dataTypes[] = {
        {DataType::INT32, "int32"},
        {DataType::UINT32, "uint32"},
        {DataType::INT64, "int64"},
        {DataType::UINT64, "uint64"},
        {DataType::INT16, "int16"},
        {DataType::UINT16, "uint16"},
        {DataType::INT8, "int8"},
        {DataType::UINT8, "uint8"},
        {DataType::FLOAT32, "float32"},
        {DataType::FLOAT64, "float64"},
        {DataType::FLOAT16, "float16"},
        {DataType::STRUCT, "struct"},
        {DataType::SAMPLER, "sampler"},
        {DataType::IMAGE, "image"},
        {DataType::SAMPLED_IMAGE, "combined"},
        {DataType::PTR_IMAGE, "ptr_image"},
        {DataType::PTR_SAMPLER, "ptr_sampler"},
        {DataType::PTR_SAMPLED_IMAGE, "ptr_combined"},
        {DataType::PTR_TEXEL, "ptr_texel"},
        {DataType::OP_NULL, "op_null"},
        {DataType::OP_UNDEF, "op_undef"},
    };

    struct
    {
        VectorType vectorType;
        const char *prefix;
    } vectorTypes[] = {
        {VectorType::SCALAR, ""}, {VectorType::V2, "v2"}, {VectorType::V3, "v3"},
        {VectorType::V4, "v4"}, {VectorType::A5, "a5"},
    };

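    // Create one test per (call type, data type, vector type) combination, skipping variants that do not apply.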
    for (int callTypeIdx = 0; callTypeIdx < DE_LENGTH_OF_ARRAY(callTypes); ++callTypeIdx)
    {
        const auto &entryCallTypes = callTypes[callTypeIdx];

        de::MovePtr<tcu::TestCaseGroup> callTypeGroup(new tcu::TestCaseGroup(testCtx, entryCallTypes.name));
        for (int dataTypeIdx = 0; dataTypeIdx < DE_LENGTH_OF_ARRAY(dataTypes); ++dataTypeIdx)
        {
            const auto &entryDataTypes = dataTypes[dataTypeIdx];

            for (int vectorTypeIdx = 0; vectorTypeIdx < DE_LENGTH_OF_ARRAY(vectorTypes); ++vectorTypeIdx)
            {
                const auto &entryVectorTypes = vectorTypes[vectorTypeIdx];

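                // Opaque resources, structs and null/undef cases are only tested as scalars; skip their vector and array variants.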
                if ((samplersNeeded(entryDataTypes.dataType) || storageImageNeeded(entryDataTypes.dataType) ||
                     entryDataTypes.dataType == DataType::STRUCT || entryDataTypes.dataType == DataType::OP_NULL ||
                     entryDataTypes.dataType == DataType::OP_UNDEF) &&
                    entryVectorTypes.vectorType != VectorType::SCALAR)
                {
                    continue;
                }

                DataSpillTestCase::TestParams params;
                params.callType = entryCallTypes.callType;
                params.dataType = entryDataTypes.dataType;
                params.vectorType = entryVectorTypes.vectorType;

                const auto testName = std::string(entryVectorTypes.prefix) + entryDataTypes.name;

                callTypeGroup->addChild(new DataSpillTestCase(testCtx, testName, params));
            }
        }

        group->addChild(callTypeGroup.release());
    }

    // Pipeline interface tests.
    // Test data spilling and unspilling of pipeline interface variables.
    de::MovePtr<tcu::TestCaseGroup> pipelineInterfaceGroup(new tcu::TestCaseGroup(testCtx, "pipeline_interface"));

    struct
    {
        InterfaceType interfaceType;
        const char *name;
    } interfaceTypes[] = {
        {InterfaceType::RAY_PAYLOAD, "ray_payload"},
        {InterfaceType::CALLABLE_DATA, "callable_data"},
        {InterfaceType::HIT_ATTRIBUTES, "hit_attributes"},
        {InterfaceType::SHADER_RECORD_BUFFER_RGEN, "shader_record_buffer_rgen"},
        {InterfaceType::SHADER_RECORD_BUFFER_CALL, "shader_record_buffer_call"},
        {InterfaceType::SHADER_RECORD_BUFFER_MISS, "shader_record_buffer_miss"},
        {InterfaceType::SHADER_RECORD_BUFFER_HIT, "shader_record_buffer_hit"},
    };

    for (int idx = 0; idx < DE_LENGTH_OF_ARRAY(interfaceTypes); ++idx)
    {
        const auto &entry = interfaceTypes[idx];
        DataSpillPipelineInterfaceTestCase::TestParams params;

        params.interfaceType = entry.interfaceType;

        pipelineInterfaceGroup->addChild(new DataSpillPipelineInterfaceTestCase(testCtx, entry.name, params));
    }

    group->addChild(pipelineInterfaceGroup.release());

    return group.release();
}

} // namespace RayTracing
} // namespace vkt