xref: /aosp_15_r20/external/deqp/external/vulkancts/modules/vulkan/subgroups/vktSubgroupsShuffleTests.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019, 2021-2023 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "vktSubgroupsShuffleTests.hpp"
27 #include "vktSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 using namespace vk;
35 using namespace vkt;
36 
37 namespace
38 {
// Subgroup operation exercised by a test case. The enumerators are also used
// as loop indices when generating cases, so their order and values matter.
enum OpType
{
    OPTYPE_SHUFFLE = 0,      // subgroupShuffle
    OPTYPE_SHUFFLE_XOR,      // subgroupShuffleXor
    OPTYPE_SHUFFLE_UP,       // subgroupShuffleUp
    OPTYPE_SHUFFLE_DOWN,     // subgroupShuffleDown
    OPTYPE_ROTATE,           // subgroupRotate
    OPTYPE_CLUSTERED_ROTATE, // subgroupClusteredRotate
    OPTYPE_LAST              // Sentinel: number of operations, not an operation itself.
};
49 
// How the second argument of the operation (the id/mask/delta operand) is
// produced in the generated shader. Applies to Xor, Up and Down as well as
// the other operations handled in this file.
enum class ArgType
{
    DYNAMIC = 0,         // Read per invocation: data2[gl_SubgroupInvocationID].
    DYNAMICALLY_UNIFORM, // Read from data2[0], identical for every invocation.
    CONSTANT             // Literal constant embedded in the shader source.
};
57 
58 struct CaseDefinition
59 {
60     OpType opType;
61     VkShaderStageFlags shaderStage;
62     VkFormat format;
63     de::SharedPtr<bool> geometryPointSizeSupported;
64     bool requiredSubgroupSize;
65     ArgType argType;
66     bool requires8BitUniformBuffer;
67     bool requires16BitUniformBuffer;
68 };
69 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,uint32_t width,uint32_t)70 static bool checkVertexPipelineStages(const void *internalData, vector<const void *> datas, uint32_t width, uint32_t)
71 {
72     DE_UNREF(internalData);
73 
74     return subgroups::check(datas, width, 1);
75 }
76 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)77 static bool checkComputeOrMesh(const void *internalData, vector<const void *> datas, const uint32_t numWorkgroups[3],
78                                const uint32_t localSize[3], uint32_t)
79 {
80     DE_UNREF(internalData);
81 
82     return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 1);
83 }
84 
getOpTypeName(OpType opType)85 string getOpTypeName(OpType opType)
86 {
87     switch (opType)
88     {
89     case OPTYPE_SHUFFLE:
90         return "subgroupShuffle";
91     case OPTYPE_SHUFFLE_XOR:
92         return "subgroupShuffleXor";
93     case OPTYPE_SHUFFLE_UP:
94         return "subgroupShuffleUp";
95     case OPTYPE_SHUFFLE_DOWN:
96         return "subgroupShuffleDown";
97     case OPTYPE_ROTATE:
98         return "subgroupRotate";
99     case OPTYPE_CLUSTERED_ROTATE:
100         return "subgroupClusteredRotate";
101     default:
102         TCU_THROW(InternalError, "Unsupported op type");
103     }
104 }
105 
getExtensionForOpType(OpType opType)106 string getExtensionForOpType(OpType opType)
107 {
108     switch (opType)
109     {
110     case OPTYPE_SHUFFLE:
111         return "GL_KHR_shader_subgroup_shuffle";
112     case OPTYPE_SHUFFLE_XOR:
113         return "GL_KHR_shader_subgroup_shuffle";
114     case OPTYPE_SHUFFLE_UP:
115         return "GL_KHR_shader_subgroup_shuffle_relative";
116     case OPTYPE_SHUFFLE_DOWN:
117         return "GL_KHR_shader_subgroup_shuffle_relative";
118     case OPTYPE_ROTATE:
119         return "GL_KHR_shader_subgroup_rotate";
120     case OPTYPE_CLUSTERED_ROTATE:
121         return "GL_KHR_shader_subgroup_rotate";
122     default:
123         TCU_THROW(InternalError, "Unsupported op type");
124     }
125 }
126 
getExtHeader(const CaseDefinition & caseDef)127 string getExtHeader(const CaseDefinition &caseDef)
128 {
129     return string("#extension ") + getExtensionForOpType(caseDef.opType) +
130            ": enable\n"
131            "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
132            subgroups::getAdditionalExtensionForFormat(caseDef.format);
133 }
134 
getPerStageHeadDeclarations(const CaseDefinition & caseDef)135 vector<string> getPerStageHeadDeclarations(const CaseDefinition &caseDef)
136 {
137     const string formatName   = subgroups::getFormatNameForGLSL(caseDef.format);
138     const uint32_t stageCount = subgroups::getStagesCount(caseDef.shaderStage);
139     const bool fragment       = (caseDef.shaderStage & VK_SHADER_STAGE_FRAGMENT_BIT) != 0;
140     const size_t resultSize   = stageCount + (fragment ? 1 : 0);
141     vector<string> result(resultSize, string());
142 
143     for (uint32_t i = 0; i < result.size(); ++i)
144     {
145         const uint32_t binding0 = i;
146         const uint32_t binding1 = stageCount;
147         const uint32_t binding2 = stageCount + 1;
148         const string buffer1    = (i == stageCount) ? "layout(location = 0) out uint result;\n" :
149                                                       "layout(set = 0, binding = " + de::toString(binding0) +
150                                                        ", std430) buffer Buffer1\n"
151                                                           "{\n"
152                                                           "  uint result[];\n"
153                                                           "};\n";
154 
155         const string b2Layout = ((caseDef.argType == ArgType::DYNAMIC) ? "std430" : "std140");
156         const string b2Type   = ((caseDef.argType == ArgType::DYNAMIC) ? "readonly buffer" : "uniform");
157 
158         result[i] = buffer1 + "layout(set = 0, binding = " + de::toString(binding1) +
159                     ", std430) readonly buffer Buffer2\n"
160                     "{\n"
161                     "  " +
162                     formatName +
163                     " data1[];\n"
164                     "};\n"
165                     "layout(set = 0, binding = " +
166                     de::toString(binding2) + ", " + b2Layout + ") " + b2Type +
167                     " Buffer3\n"
168                     "{\n"
169                     "  uint data2[];\n"
170                     "};\n";
171     }
172 
173     return result;
174 }
175 
getFramebufferPerStageHeadDeclarations(const CaseDefinition & caseDef)176 vector<string> getFramebufferPerStageHeadDeclarations(const CaseDefinition &caseDef)
177 {
178     const string formatName   = subgroups::getFormatNameForGLSL(caseDef.format);
179     const uint32_t stageCount = subgroups::getStagesCount(caseDef.shaderStage);
180     vector<string> result(stageCount, string());
181     const auto b2Len = ((caseDef.argType == ArgType::DYNAMIC) ? subgroups::maxSupportedSubgroupSize() : 1u);
182     const string buffer2{"layout(set = 0, binding = 0) uniform Buffer1\n"
183                          "{\n"
184                          "  " +
185                          formatName + " data1[" + de::toString(subgroups::maxSupportedSubgroupSize()) +
186                          "];\n"
187                          "};\n"
188                          "layout(set = 0, binding = 1) uniform Buffer2\n"
189                          "{\n"
190                          "  uint data2[" +
191                          de::toString(b2Len) +
192                          "];\n"
193                          "};\n"};
194 
195     for (size_t i = 0; i < result.size(); ++i)
196     {
197         switch (i)
198         {
199         case 0:
200             result[i] = "layout(location = 0) out float result;\n" + buffer2;
201             break;
202         case 1:
203             result[i] = "layout(location = 0) out float out_color;\n" + buffer2;
204             break;
205         case 2:
206             result[i] = "layout(location = 0) out float out_color[];\n" + buffer2;
207             break;
208         case 3:
209             result[i] = "layout(location = 0) out float out_color;\n" + buffer2;
210             break;
211         default:
212             TCU_THROW(InternalError, "Unknown stage");
213         }
214     }
215 
216     return result;
217 }
218 
getNonClusteredTestSource(const CaseDefinition & caseDef)219 const string getNonClusteredTestSource(const CaseDefinition &caseDef)
220 {
221     const string id = caseDef.opType == OPTYPE_SHUFFLE      ? "id_in" :
222                       caseDef.opType == OPTYPE_SHUFFLE_XOR  ? "gl_SubgroupInvocationID ^ id_in" :
223                       caseDef.opType == OPTYPE_SHUFFLE_UP   ? "gl_SubgroupInvocationID - id_in" :
224                       caseDef.opType == OPTYPE_SHUFFLE_DOWN ? "gl_SubgroupInvocationID + id_in" :
225                       caseDef.opType == OPTYPE_ROTATE ? "(gl_SubgroupInvocationID + id_in) & (gl_SubgroupSize - 1)" :
226                                                         "";
227     const string idInSource =
228         caseDef.argType == ArgType::DYNAMIC ?
229             "data2[gl_SubgroupInvocationID] & (gl_SubgroupSize - 1)" :
230         caseDef.argType == ArgType::DYNAMICALLY_UNIFORM ?
231             (caseDef.opType == OPTYPE_ROTATE ? "data2[0] & (gl_SubgroupSize * 2 - 1)" : "data2[0] % 32") :
232         caseDef.argType == ArgType::CONSTANT ? "5" :
233                                                "";
234     const string testSource = "  uint temp_res;\n"
235                               "  uvec4 mask = subgroupBallot(true);\n"
236                               "  uint id_in = " +
237                               idInSource +
238                               ";\n"
239                               "  " +
240                               subgroups::getFormatNameForGLSL(caseDef.format) +
241                               " op = " + getOpTypeName(caseDef.opType) +
242                               "(data1[gl_SubgroupInvocationID], id_in);\n"
243                               "  uint id = " +
244                               id +
245                               ";\n"
246                               "  if ((id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
247                               "  {\n"
248                               "    temp_res = (op == data1[id]) ? 1 : 0;\n"
249                               "  }\n"
250                               "  else\n"
251                               "  {\n"
252                               "    temp_res = 1; // Invocation we read from was inactive, so we can't verify results!\n"
253                               "  }\n"
254                               "  tempRes = temp_res;\n";
255 
256     return testSource;
257 }
258 
getClusteredTestSource(const CaseDefinition & caseDef)259 const string getClusteredTestSource(const CaseDefinition &caseDef)
260 {
261     const string idInSource = caseDef.argType == ArgType::DYNAMICALLY_UNIFORM ? "data2[0] & (gl_SubgroupSize * 2 - 1)" :
262                               caseDef.argType == ArgType::CONSTANT            ? "5" :
263                                                                                 "";
264     const string testSource =
265         "  uint temp_res = 1;\n"
266         "  uvec4 mask = subgroupBallot(true);\n"
267         "  uint cluster_size;\n"
268         "  for (cluster_size = 1; cluster_size <= gl_SubgroupSize; cluster_size *= 2)\n"
269         "  {\n"
270         "    uint id_in = " +
271         idInSource +
272         ";\n"
273         "    uint cluster_res;\n"
274         "    " +
275         subgroups::getFormatNameForGLSL(caseDef.format) +
276         " data1_val = data1[gl_SubgroupInvocationID];\n"
277         "    " +
278         subgroups::getFormatNameForGLSL(caseDef.format) +
279         " op;\n"
280         "    switch (cluster_size)\n"
281         "    {\n"
282         "      case 1: op = " +
283         getOpTypeName(caseDef.opType) +
284         "(data1_val, id_in, 1u); break;\n"
285         "      case 2: op = " +
286         getOpTypeName(caseDef.opType) +
287         "(data1_val, id_in, 2u); break;\n"
288         "      case 4: op = " +
289         getOpTypeName(caseDef.opType) +
290         "(data1_val, id_in, 4u); break;\n"
291         "      case 8: op = " +
292         getOpTypeName(caseDef.opType) +
293         "(data1_val, id_in, 8u); break;\n"
294         "      case 16: op = " +
295         getOpTypeName(caseDef.opType) +
296         "(data1_val, id_in, 16u); break;\n"
297         "      case 32: op = " +
298         getOpTypeName(caseDef.opType) +
299         "(data1_val, id_in, 32u); break;\n"
300         "      case 64: op = " +
301         getOpTypeName(caseDef.opType) +
302         "(data1_val, id_in, 64u); break;\n"
303         "      case 128: op = " +
304         getOpTypeName(caseDef.opType) +
305         "(data1_val, id_in, 128u); break;\n"
306         "    }\n"
307         "    uint id = ((gl_SubgroupInvocationID + id_in) & (cluster_size - 1)) | (gl_SubgroupInvocationID & "
308         "~(cluster_size - 1));\n"
309         "    if ((id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
310         "    {\n"
311         "      cluster_res = (op == data1[id]) ? 1 : 0;\n"
312         "    }\n"
313         "    else\n"
314         "    {\n"
315         "      cluster_res = 1; // Invocation we read from was inactive, so we can't verify results!\n"
316         "    }\n"
317         "    temp_res = (temp_res & cluster_res);\n"
318         "  }\n"
319         "  tempRes = temp_res;\n";
320 
321     return testSource;
322 }
323 
getTestSource(const CaseDefinition & caseDef)324 const string getTestSource(const CaseDefinition &caseDef)
325 {
326     if (caseDef.opType == OPTYPE_CLUSTERED_ROTATE)
327     {
328         return getClusteredTestSource(caseDef);
329     }
330     return getNonClusteredTestSource(caseDef);
331 }
332 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)333 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
334 {
335     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
336     const string extHeader                = getExtHeader(caseDef);
337     const string testSrc                  = getTestSource(caseDef);
338     const vector<string> headDeclarations = getFramebufferPerStageHeadDeclarations(caseDef);
339     const bool pointSizeSupported         = *caseDef.geometryPointSizeSupported;
340 
341     subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, VK_FORMAT_R32_UINT,
342                                           pointSizeSupported, extHeader, testSrc, "", headDeclarations);
343 }
344 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)345 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
346 {
347 #ifndef CTS_USES_VULKANSC
348     const bool spirv14required =
349         (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
350 #else
351     const bool spirv14required = false;
352 #endif // CTS_USES_VULKANSC
353     const SpirvVersion spirvVersion = spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
354     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
355     const string extHeader                = getExtHeader(caseDef);
356     const string testSrc                  = getTestSource(caseDef);
357     const vector<string> headDeclarations = getPerStageHeadDeclarations(caseDef);
358     const bool pointSizeSupported         = *caseDef.geometryPointSizeSupported;
359 
360     subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, VK_FORMAT_R32_UINT,
361                                pointSizeSupported, extHeader, testSrc, "", headDeclarations);
362 }
363 
supportedCheck(Context & context,CaseDefinition caseDef)364 void supportedCheck(Context &context, CaseDefinition caseDef)
365 {
366     if (!subgroups::isSubgroupSupported(context))
367         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
368 
369     switch (caseDef.opType)
370     {
371     case OPTYPE_SHUFFLE:
372     case OPTYPE_SHUFFLE_XOR:
373         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_BIT))
374         {
375             TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle operations");
376         }
377         break;
378 #ifndef CTS_USES_VULKANSC
379     case OPTYPE_ROTATE:
380         if (!context.getShaderSubgroupRotateFeatures().shaderSubgroupRotate)
381         {
382             TCU_THROW(NotSupportedError, "Device does not support shaderSubgroupRotate");
383         }
384         if (!subgroups::isSubgroupRotateSpecVersionValid(context))
385         {
386             TCU_THROW(NotSupportedError, "VK_KHR_shader_subgroup_rotate is version 1. Need version 2 or higher");
387         }
388         break;
389     case OPTYPE_CLUSTERED_ROTATE:
390         if (!context.getShaderSubgroupRotateFeatures().shaderSubgroupRotateClustered)
391         {
392             TCU_THROW(NotSupportedError, "Device does not support shaderSubgroupRotateClustered");
393         }
394         if (!subgroups::isSubgroupRotateSpecVersionValid(context))
395         {
396             TCU_THROW(NotSupportedError, "VK_KHR_shader_subgroup_rotate is version 1. Need version 2 or higher");
397         }
398         break;
399 #endif // CTS_USES_VULKANSC
400     default:
401         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT))
402         {
403             TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle relative operations");
404         }
405         break;
406     }
407 
408     if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
409         TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
410 
411     if (caseDef.requires16BitUniformBuffer)
412     {
413         if (!subgroups::is16BitUBOStorageSupported(context))
414         {
415             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
416         }
417     }
418 
419     if (caseDef.requires8BitUniformBuffer)
420     {
421         if (!subgroups::is8BitUBOStorageSupported(context))
422         {
423             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
424         }
425     }
426 
427     if (caseDef.requiredSubgroupSize)
428     {
429         context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
430 
431 #ifndef CTS_USES_VULKANSC
432         const VkPhysicalDeviceSubgroupSizeControlFeatures &subgroupSizeControlFeatures =
433             context.getSubgroupSizeControlFeatures();
434         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
435             context.getSubgroupSizeControlProperties();
436 #else
437         const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT &subgroupSizeControlFeatures =
438             context.getSubgroupSizeControlFeaturesEXT();
439         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
440             context.getSubgroupSizeControlPropertiesEXT();
441 #endif // CTS_USES_VULKANSC
442 
443         if (subgroupSizeControlFeatures.subgroupSizeControl == false)
444             TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
445 
446         if (subgroupSizeControlFeatures.computeFullSubgroups == false)
447             TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
448 
449         if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
450             TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
451     }
452 
453     *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
454 
455 #ifndef CTS_USES_VULKANSC
456     if (isAllRayTracingStages(caseDef.shaderStage))
457     {
458         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
459     }
460     else if (isAllMeshShadingStages(caseDef.shaderStage))
461     {
462         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
463         context.requireDeviceFunctionality("VK_EXT_mesh_shader");
464 
465         if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
466         {
467             const auto &features = context.getMeshShaderFeaturesEXT();
468             if (!features.taskShader)
469                 TCU_THROW(NotSupportedError, "Task shaders not supported");
470         }
471     }
472 #endif // CTS_USES_VULKANSC
473 
474     subgroups::supportedCheckShader(context, caseDef.shaderStage);
475 }
476 
noSSBOtest(Context & context,const CaseDefinition caseDef)477 TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
478 {
479     const VkDeviceSize secondBufferSize =
480         ((caseDef.argType == ArgType::DYNAMIC) ? subgroups::maxSupportedSubgroupSize() : 1u);
481     const subgroups::SSBOData inputData[2]{
482         {
483             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
484             subgroups::SSBOData::LayoutStd140,      //  InputDataLayoutType layout;
485             caseDef.format,                         //  vk::VkFormat format;
486             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
487             subgroups::SSBOData::BindingUBO,        //  BindingType bindingType;
488         },
489         {
490             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
491             subgroups::SSBOData::LayoutStd140,      //  InputDataLayoutType layout;
492             VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
493             secondBufferSize,                       //  vk::VkDeviceSize numElements;
494             subgroups::SSBOData::BindingUBO,        //  BindingType bindingType;
495         }};
496 
497     switch (caseDef.shaderStage)
498     {
499     case VK_SHADER_STAGE_VERTEX_BIT:
500         return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
501                                                     checkVertexPipelineStages);
502     case VK_SHADER_STAGE_GEOMETRY_BIT:
503         return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
504                                                       checkVertexPipelineStages);
505     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
506         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
507                                                                     checkVertexPipelineStages, caseDef.shaderStage);
508     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
509         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
510                                                                     checkVertexPipelineStages, caseDef.shaderStage);
511     default:
512         TCU_THROW(InternalError, "Unhandled shader stage");
513     }
514 }
515 
test(Context & context,const CaseDefinition caseDef)516 TestStatus test(Context &context, const CaseDefinition caseDef)
517 {
518     const auto secondBufferLayout =
519         ((caseDef.argType == ArgType::DYNAMIC) ? subgroups::SSBOData::LayoutStd430 : subgroups::SSBOData::LayoutStd140);
520     const VkDeviceSize secondBufferElems =
521         ((caseDef.argType == ArgType::DYNAMIC) ? subgroups::maxSupportedSubgroupSize() : 1u);
522     const auto secondBufferType =
523         ((caseDef.argType == ArgType::DYNAMIC) ? subgroups::SSBOData::BindingSSBO : subgroups::SSBOData::BindingUBO);
524 
525     const bool isCompute = isAllComputeStages(caseDef.shaderStage);
526 #ifndef CTS_USES_VULKANSC
527     const bool isMesh = isAllMeshShadingStages(caseDef.shaderStage);
528 #else
529     const bool isMesh = false;
530 #endif // CTS_USES_VULKANSC
531     DE_ASSERT(!(isCompute && isMesh));
532 
533     if (isCompute || isMesh)
534     {
535 #ifndef CTS_USES_VULKANSC
536         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
537             context.getSubgroupSizeControlProperties();
538 #else
539         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
540             context.getSubgroupSizeControlPropertiesEXT();
541 #endif // CTS_USES_VULKANSC
542         TestLog &log = context.getTestContext().getLog();
543         const subgroups::SSBOData inputData[2]{
544             {
545                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
546                 subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
547                 caseDef.format,                         //  vk::VkFormat format;
548                 subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
549             },
550             {
551                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
552                 secondBufferLayout,                     //  InputDataLayoutType layout;
553                 VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
554                 secondBufferElems,                      //  vk::VkDeviceSize numElements;
555                 secondBufferType,
556             },
557         };
558 
559         if (caseDef.requiredSubgroupSize == false)
560         {
561             if (isCompute)
562                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
563                                                   checkComputeOrMesh);
564             else
565                 return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL, checkComputeOrMesh);
566         }
567 
568         log << TestLog::Message << "Testing required subgroup size range ["
569             << subgroupSizeControlProperties.minSubgroupSize << ", " << subgroupSizeControlProperties.maxSubgroupSize
570             << "]" << TestLog::EndMessage;
571 
572         // According to the spec, requiredSubgroupSize must be a power-of-two integer.
573         for (uint32_t size = subgroupSizeControlProperties.minSubgroupSize;
574              size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
575         {
576             TestStatus result(QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
577 
578             if (isCompute)
579                 result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
580                                                     checkComputeOrMesh, size);
581             else
582                 result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL, checkComputeOrMesh,
583                                                  size);
584 
585             if (result.getCode() != QP_TEST_RESULT_PASS)
586             {
587                 log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
588                 return result;
589             }
590         }
591 
592         return TestStatus::pass("OK");
593     }
594     else if (isAllGraphicsStages(caseDef.shaderStage))
595     {
596         const VkShaderStageFlags stages = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
597         const subgroups::SSBOData inputData[2]{
598             {
599                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
600                 subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
601                 caseDef.format,                         //  vk::VkFormat format;
602                 subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
603                 subgroups::SSBOData::BindingSSBO,       //  bool isImage;
604                 4u,                                     //  uint32_t binding;
605                 stages,                                 //  vk::VkShaderStageFlags stages;
606             },
607             {
608                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
609                 secondBufferLayout,                     //  InputDataLayoutType layout;
610                 VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
611                 secondBufferElems,                      //  vk::VkDeviceSize numElements;
612                 secondBufferType,                       //  bool isImage;
613                 5u,                                     //  uint32_t binding;
614                 stages,                                 //  vk::VkShaderStageFlags stages;
615             },
616         };
617 
618         return subgroups::allStages(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL, checkVertexPipelineStages,
619                                     stages);
620     }
621 #ifndef CTS_USES_VULKANSC
622     else if (isAllRayTracingStages(caseDef.shaderStage))
623     {
624         const VkShaderStageFlags stages = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
625         const subgroups::SSBOData inputData[2]{
626             {
627                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
628                 subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
629                 caseDef.format,                         //  vk::VkFormat format;
630                 subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
631                 subgroups::SSBOData::BindingSSBO,       //  bool isImage;
632                 6u,                                     //  uint32_t binding;
633                 stages,                                 //  vk::VkShaderStageFlags stages;
634             },
635             {
636                 subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
637                 secondBufferLayout,                     //  InputDataLayoutType layout;
638                 VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
639                 secondBufferElems,                      //  vk::VkDeviceSize numElements;
640                 secondBufferType,                       //  bool isImage;
641                 7u,                                     //  uint32_t binding;
642                 stages,                                 //  vk::VkShaderStageFlags stages;
643             },
644         };
645 
646         return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, inputData, 2, DE_NULL,
647                                               checkVertexPipelineStages, stages);
648     }
649 #endif // CTS_USES_VULKANSC
650     else
651         TCU_THROW(InternalError, "Unknown stage or invalid stage set");
652 }
653 } // namespace
654 
655 namespace vkt
656 {
657 namespace subgroups
658 {
createSubgroupsShuffleTests(TestContext & testCtx)659 TestCaseGroup *createSubgroupsShuffleTests(TestContext &testCtx)
660 {
661     de::MovePtr<TestCaseGroup> group(new TestCaseGroup(testCtx, "shuffle"));
662 
663     de::MovePtr<TestCaseGroup> graphicGroup(new TestCaseGroup(testCtx, "graphics"));
664     de::MovePtr<TestCaseGroup> computeGroup(new TestCaseGroup(testCtx, "compute"));
665     de::MovePtr<TestCaseGroup> framebufferGroup(new TestCaseGroup(testCtx, "framebuffer"));
666 #ifndef CTS_USES_VULKANSC
667     de::MovePtr<TestCaseGroup> raytracingGroup(new TestCaseGroup(testCtx, "ray_tracing"));
668     de::MovePtr<TestCaseGroup> meshGroup(new TestCaseGroup(testCtx, "mesh"));
669 #endif // CTS_USES_VULKANSC
670 
671     const VkShaderStageFlags fbStages[] = {
672         VK_SHADER_STAGE_VERTEX_BIT,
673         VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
674         VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
675         VK_SHADER_STAGE_GEOMETRY_BIT,
676     };
677 
678 #ifndef CTS_USES_VULKANSC
679     const VkShaderStageFlags meshStages[] = {
680         VK_SHADER_STAGE_MESH_BIT_EXT,
681         VK_SHADER_STAGE_TASK_BIT_EXT,
682     };
683 #endif // CTS_USES_VULKANSC
684 
685     const bool boolValues[] = {false, true};
686 
687     const struct
688     {
689         ArgType argType;
690         const char *suffix;
691     } argCases[] = {
692         {ArgType::DYNAMIC, ""},
693         {ArgType::DYNAMICALLY_UNIFORM, "_dynamically_uniform"},
694         {ArgType::CONSTANT, "_constant"},
695     };
696 
697     {
698         const vector<VkFormat> formats = subgroups::getAllFormats();
699 
700         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
701         {
702             const VkFormat format           = formats[formatIndex];
703             const string formatName         = subgroups::getFormatNameForGLSL(format);
704             const bool needs8BitUBOStorage  = isFormat8bitTy(format);
705             const bool needs16BitUBOStorage = isFormat16BitTy(format);
706 
707             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
708             {
709                 for (const auto &argCase : argCases)
710                 {
711                     const OpType opType = static_cast<OpType>(opTypeIndex);
712 
713                     if (opType == OPTYPE_SHUFFLE && argCase.argType != ArgType::DYNAMIC)
714                         continue;
715 
716                     if ((opType == OPTYPE_ROTATE || opType == OPTYPE_CLUSTERED_ROTATE) &&
717                         argCase.argType == ArgType::DYNAMIC)
718                         continue;
719 
720                     const string name = de::toLower(getOpTypeName(opType)) + "_" + formatName + argCase.suffix;
721 
722                     {
723                         const CaseDefinition caseDef = {
724                             opType,                        //  OpType opType;
725                             VK_SHADER_STAGE_ALL_GRAPHICS,  //  VkShaderStageFlags shaderStage;
726                             format,                        //  VkFormat format;
727                             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
728                             false,                         //  bool requiredSubgroupSize;
729                             argCase.argType,               //  ArgType argType;
730                             false,                         //  bool requires8BitUniformBuffer;
731                             false                          //  bool requires16BitUniformBuffer;
732                         };
733 
734                         addFunctionCaseWithPrograms(graphicGroup.get(), name, supportedCheck, initPrograms, test,
735                                                     caseDef);
736                     }
737 
738                     for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
739                     {
740                         const bool requiredSubgroupSize = boolValues[groupSizeNdx];
741                         const string testName           = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
742                         const CaseDefinition caseDef    = {
743                             opType,                        //  OpType opType;
744                             VK_SHADER_STAGE_COMPUTE_BIT,   //  VkShaderStageFlags shaderStage;
745                             format,                        //  VkFormat format;
746                             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
747                             requiredSubgroupSize,          //  bool requiredSubgroupSize;
748                             argCase.argType,               //  ArgType argType;
749                             false,                         //  bool requires8BitUniformBuffer;
750                             false                          //  bool requires16BitUniformBuffer;
751                         };
752 
753                         addFunctionCaseWithPrograms(computeGroup.get(), testName, supportedCheck, initPrograms, test,
754                                                     caseDef);
755                     }
756 
757 #ifndef CTS_USES_VULKANSC
758                     for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
759                     {
760                         for (const auto &stage : meshStages)
761                         {
762                             const bool requiredSubgroupSize = boolValues[groupSizeNdx];
763                             const string testName = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "") + "_" +
764                                                     getShaderStageName(stage);
765                             const CaseDefinition caseDef = {
766                                 opType,                        //  OpType opType;
767                                 stage,                         //  VkShaderStageFlags shaderStage;
768                                 format,                        //  VkFormat format;
769                                 de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
770                                 requiredSubgroupSize,          //  bool requiredSubgroupSize;
771                                 argCase.argType,               //  ArgType argType;
772                                 false,                         //  bool requires8BitUniformBuffer;
773                                 false,                         //  bool requires16BitUniformBuffer;
774                             };
775 
776                             addFunctionCaseWithPrograms(meshGroup.get(), testName, supportedCheck, initPrograms, test,
777                                                         caseDef);
778                         }
779                     }
780 #endif // CTS_USES_VULKANSC
781 
782                     for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
783                     {
784                         const CaseDefinition caseDef = {
785                             opType,                        //  OpType opType;
786                             fbStages[stageIndex],          //  VkShaderStageFlags shaderStage;
787                             format,                        //  VkFormat format;
788                             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
789                             false,                         //  bool requiredSubgroupSize;
790                             argCase.argType,               //  ArgType argType;
791                             bool(needs8BitUBOStorage),     //  bool requires8BitUniformBuffer;
792                             bool(needs16BitUBOStorage)     //  bool requires16BitUniformBuffer;
793                         };
794                         const string testName = name + "_" + getShaderStageName(caseDef.shaderStage);
795 
796                         addFunctionCaseWithPrograms(framebufferGroup.get(), testName, supportedCheck,
797                                                     initFrameBufferPrograms, noSSBOtest, caseDef);
798                     }
799                 }
800             }
801         }
802     }
803 
804 #ifndef CTS_USES_VULKANSC
805     {
806         const vector<VkFormat> formats = subgroups::getAllRayTracingFormats();
807 
808         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
809         {
810             const VkFormat format   = formats[formatIndex];
811             const string formatName = subgroups::getFormatNameForGLSL(format);
812 
813             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
814             {
815                 for (const auto &argCase : argCases)
816                 {
817                     const OpType opType = static_cast<OpType>(opTypeIndex);
818 
819                     if (opType == OPTYPE_SHUFFLE && argCase.argType != ArgType::DYNAMIC)
820                         continue;
821 
822                     if ((opType == OPTYPE_ROTATE || opType == OPTYPE_CLUSTERED_ROTATE) &&
823                         argCase.argType == ArgType::DYNAMIC)
824                         continue;
825 
826                     const string name = de::toLower(getOpTypeName(opType)) + "_" + formatName + argCase.suffix;
827                     const CaseDefinition caseDef = {
828                         opType,                        //  OpType opType;
829                         SHADER_STAGE_ALL_RAY_TRACING,  //  VkShaderStageFlags shaderStage;
830                         format,                        //  VkFormat format;
831                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
832                         false,                         //  bool requiredSubgroupSize;
833                         argCase.argType,               //  ArgType argType;
834                         false,                         //  bool requires8BitUniformBuffer;
835                         false                          //  bool requires16BitUniformBuffer;
836                     };
837 
838                     addFunctionCaseWithPrograms(raytracingGroup.get(), name, supportedCheck, initPrograms, test,
839                                                 caseDef);
840                 }
841             }
842         }
843     }
844 #endif // CTS_USES_VULKANSC
845 
846     group->addChild(graphicGroup.release());
847     group->addChild(computeGroup.release());
848     group->addChild(framebufferGroup.release());
849 #ifndef CTS_USES_VULKANSC
850     group->addChild(raytracingGroup.release());
851     group->addChild(meshGroup.release());
852 #endif // CTS_USES_VULKANSC
853 
854     return group.release();
855 }
856 
857 } // namespace subgroups
858 } // namespace vkt
859