xref: /aosp_15_r20/external/deqp/external/vulkancts/modules/vulkan/subgroups/vktSubgroupsArithmeticTests.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "vktSubgroupsArithmeticTests.hpp"
27 #include "vktSubgroupsScanHelpers.hpp"
28 #include "vktSubgroupsTestsUtils.hpp"
29 
30 #include <string>
31 #include <vector>
32 
33 using namespace tcu;
34 using namespace std;
35 using namespace vk;
36 using namespace vkt;
37 
38 namespace
39 {
// Flat enumeration of every subgroup arithmetic variant covered by this file.
//
// The enumerators form three consecutive families that each list the same
// seven operators in the same order: plain reductions, inclusive scans and
// exclusive scans. getOperator() and getScanType() split a value back into
// its operator and scan kind; OPTYPE_LAST is the iteration bound.
enum OpType
{
    // Reductions.
    OPTYPE_ADD = 0,
    OPTYPE_MUL,
    OPTYPE_MIN,
    OPTYPE_MAX,
    OPTYPE_AND,
    OPTYPE_OR,
    OPTYPE_XOR,

    // Inclusive scans.
    OPTYPE_INCLUSIVE_ADD,
    OPTYPE_INCLUSIVE_MUL,
    OPTYPE_INCLUSIVE_MIN,
    OPTYPE_INCLUSIVE_MAX,
    OPTYPE_INCLUSIVE_AND,
    OPTYPE_INCLUSIVE_OR,
    OPTYPE_INCLUSIVE_XOR,

    // Exclusive scans.
    OPTYPE_EXCLUSIVE_ADD,
    OPTYPE_EXCLUSIVE_MUL,
    OPTYPE_EXCLUSIVE_MIN,
    OPTYPE_EXCLUSIVE_MAX,
    OPTYPE_EXCLUSIVE_AND,
    OPTYPE_EXCLUSIVE_OR,
    OPTYPE_EXCLUSIVE_XOR,

    // Number of variants; not a valid operation.
    OPTYPE_LAST
};
65 
66 struct CaseDefinition
67 {
68     Operator op;
69     ScanType scanType;
70     VkShaderStageFlags shaderStage;
71     VkFormat format;
72     de::SharedPtr<bool> geometryPointSizeSupported;
73     bool requiredSubgroupSize;
74     bool requires8BitUniformBuffer;
75     bool requires16BitUniformBuffer;
76 };
77 
getOperator(OpType opType)78 static Operator getOperator(OpType opType)
79 {
80     switch (opType)
81     {
82     case OPTYPE_ADD:
83     case OPTYPE_INCLUSIVE_ADD:
84     case OPTYPE_EXCLUSIVE_ADD:
85         return OPERATOR_ADD;
86     case OPTYPE_MUL:
87     case OPTYPE_INCLUSIVE_MUL:
88     case OPTYPE_EXCLUSIVE_MUL:
89         return OPERATOR_MUL;
90     case OPTYPE_MIN:
91     case OPTYPE_INCLUSIVE_MIN:
92     case OPTYPE_EXCLUSIVE_MIN:
93         return OPERATOR_MIN;
94     case OPTYPE_MAX:
95     case OPTYPE_INCLUSIVE_MAX:
96     case OPTYPE_EXCLUSIVE_MAX:
97         return OPERATOR_MAX;
98     case OPTYPE_AND:
99     case OPTYPE_INCLUSIVE_AND:
100     case OPTYPE_EXCLUSIVE_AND:
101         return OPERATOR_AND;
102     case OPTYPE_OR:
103     case OPTYPE_INCLUSIVE_OR:
104     case OPTYPE_EXCLUSIVE_OR:
105         return OPERATOR_OR;
106     case OPTYPE_XOR:
107     case OPTYPE_INCLUSIVE_XOR:
108     case OPTYPE_EXCLUSIVE_XOR:
109         return OPERATOR_XOR;
110     default:
111         DE_FATAL("Unsupported op type");
112         return OPERATOR_ADD;
113     }
114 }
115 
getScanType(OpType opType)116 static ScanType getScanType(OpType opType)
117 {
118     switch (opType)
119     {
120     case OPTYPE_ADD:
121     case OPTYPE_MUL:
122     case OPTYPE_MIN:
123     case OPTYPE_MAX:
124     case OPTYPE_AND:
125     case OPTYPE_OR:
126     case OPTYPE_XOR:
127         return SCAN_REDUCE;
128     case OPTYPE_INCLUSIVE_ADD:
129     case OPTYPE_INCLUSIVE_MUL:
130     case OPTYPE_INCLUSIVE_MIN:
131     case OPTYPE_INCLUSIVE_MAX:
132     case OPTYPE_INCLUSIVE_AND:
133     case OPTYPE_INCLUSIVE_OR:
134     case OPTYPE_INCLUSIVE_XOR:
135         return SCAN_INCLUSIVE;
136     case OPTYPE_EXCLUSIVE_ADD:
137     case OPTYPE_EXCLUSIVE_MUL:
138     case OPTYPE_EXCLUSIVE_MIN:
139     case OPTYPE_EXCLUSIVE_MAX:
140     case OPTYPE_EXCLUSIVE_AND:
141     case OPTYPE_EXCLUSIVE_OR:
142     case OPTYPE_EXCLUSIVE_XOR:
143         return SCAN_EXCLUSIVE;
144     default:
145         DE_FATAL("Unsupported op type");
146         return SCAN_REDUCE;
147     }
148 }
149 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,uint32_t width,uint32_t)150 static bool checkVertexPipelineStages(const void *internalData, vector<const void *> datas, uint32_t width, uint32_t)
151 {
152     DE_UNREF(internalData);
153 
154     return subgroups::check(datas, width, 0x3);
155 }
156 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)157 static bool checkComputeOrMesh(const void *internalData, vector<const void *> datas, const uint32_t numWorkgroups[3],
158                                const uint32_t localSize[3], uint32_t)
159 {
160     DE_UNREF(internalData);
161 
162     return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 0x3);
163 }
164 
getOpTypeName(Operator op,ScanType scanType)165 string getOpTypeName(Operator op, ScanType scanType)
166 {
167     return getScanOpName("subgroup", "", op, scanType);
168 }
169 
getExtHeader(const CaseDefinition & caseDef)170 string getExtHeader(const CaseDefinition &caseDef)
171 {
172     return "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
173            "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
174            subgroups::getAdditionalExtensionForFormat(caseDef.format);
175 }
176 
getIndexVars(const CaseDefinition & caseDef)177 string getIndexVars(const CaseDefinition &caseDef)
178 {
179     switch (caseDef.scanType)
180     {
181     case SCAN_REDUCE:
182         return "  uint start = 0, end = gl_SubgroupSize;\n";
183     case SCAN_INCLUSIVE:
184         return "  uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
185     case SCAN_EXCLUSIVE:
186         return "  uint start = 0, end = gl_SubgroupInvocationID;\n";
187     default:
188         TCU_THROW(InternalError, "Unreachable");
189     }
190 }
191 
getTestSrc(const CaseDefinition & caseDef)192 string getTestSrc(const CaseDefinition &caseDef)
193 {
194     const string indexVars = getIndexVars(caseDef);
195 
196     string shader = "  uvec4 mask = subgroupBallot(true);\n" + indexVars + "  " +
197                     subgroups::getFormatNameForGLSL(caseDef.format) +
198                     " ref = " + getIdentity(caseDef.op, caseDef.format) +
199                     ";\n"
200                     "  tempRes = 0;\n"
201                     "  uint identityOnly = 0x3\n;"
202                     "  for (uint index = start; index < end; index++)\n"
203                     "  {\n"
204                     "    if (subgroupBallotBitExtract(mask, index))\n"
205                     "    {\n"
206                     "      ref = " +
207                     getOpOperation(caseDef.op, caseDef.format, "ref", "data[index]") +
208                     ";\n"
209                     "      identityOnly &= ~0x1;\n"
210                     "    }\n"
211                     "  }\n"
212                     "  tempRes = " +
213                     getCompare(caseDef.op, caseDef.format, "ref",
214                                getOpTypeName(caseDef.op, caseDef.scanType) + "(data[gl_SubgroupInvocationID])") +
215                     " ? 0x1 : 0;\n"
216                     "  if (1 == (gl_SubgroupInvocationID % 2))\n"
217                     "  {\n"
218                     "    mask = subgroupBallot(true);\n"
219                     "    ref = " +
220                     getIdentity(caseDef.op, caseDef.format) +
221                     ";\n"
222                     "    for (uint index = start; index < end; index++)\n"
223                     "    {\n"
224                     "      if (subgroupBallotBitExtract(mask, index))\n"
225                     "      {\n"
226                     "        ref = " +
227                     getOpOperation(caseDef.op, caseDef.format, "ref", "data[index]") +
228                     ";\n"
229                     "        identityOnly &= ~0x2;\n"
230                     "      }\n"
231                     "    }\n"
232                     "    tempRes |= " +
233                     getCompare(caseDef.op, caseDef.format, "ref",
234                                getOpTypeName(caseDef.op, caseDef.scanType) + "(data[gl_SubgroupInvocationID])") +
235                     " ? 0x2 : 0;\n"
236                     "  }\n"
237                     "  else\n"
238                     "  {\n"
239                     "    tempRes |= 0x2;\n"
240                     "  }\n";
241 
242     // Can't test max or min identity as they are +/-inf, which the SPIR-V
243     // compiler is allowed to assume don't occur in the program
244     if (caseDef.op == OPERATOR_MIN || caseDef.op == OPERATOR_MAX)
245         shader += "  tempRes |= identityOnly;\n";
246 
247     return shader;
248 }
249 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)250 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
251 {
252     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
253     const string extHeader = getExtHeader(caseDef);
254     const string testSrc   = getTestSrc(caseDef);
255 
256     subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format,
257                                           *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
258 }
259 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)260 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
261 {
262 #ifndef CTS_USES_VULKANSC
263     const bool spirv14required =
264         (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
265 #else
266     const bool spirv14required = false;
267 #endif // CTS_USES_VULKANSC
268     const SpirvVersion spirvVersion = spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
269     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
270     const string extHeader = getExtHeader(caseDef);
271     const string testSrc   = getTestSrc(caseDef);
272 
273     subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format,
274                                *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
275 }
276 
supportedCheck(Context & context,CaseDefinition caseDef)277 void supportedCheck(Context &context, CaseDefinition caseDef)
278 {
279     if (!subgroups::isSubgroupSupported(context))
280         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
281 
282     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_ARITHMETIC_BIT))
283         TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
284 
285     if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
286         TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
287 
288     if (caseDef.requires16BitUniformBuffer)
289     {
290         if (!subgroups::is16BitUBOStorageSupported(context))
291         {
292             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
293         }
294     }
295 
296     if (caseDef.requires8BitUniformBuffer)
297     {
298         if (!subgroups::is8BitUBOStorageSupported(context))
299         {
300             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
301         }
302     }
303 
304     if (caseDef.requiredSubgroupSize)
305     {
306         context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
307 
308 #ifndef CTS_USES_VULKANSC
309         const VkPhysicalDeviceSubgroupSizeControlFeatures &subgroupSizeControlFeatures =
310             context.getSubgroupSizeControlFeatures();
311         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
312             context.getSubgroupSizeControlProperties();
313 #else
314         const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT &subgroupSizeControlFeatures =
315             context.getSubgroupSizeControlFeaturesEXT();
316         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
317             context.getSubgroupSizeControlPropertiesEXT();
318 #endif // CTS_USES_VULKANSC
319 
320         if (subgroupSizeControlFeatures.subgroupSizeControl == false)
321             TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
322 
323         if (subgroupSizeControlFeatures.computeFullSubgroups == false)
324             TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
325 
326         if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
327             TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
328     }
329 
330     *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
331 
332 #ifndef CTS_USES_VULKANSC
333     if (isAllRayTracingStages(caseDef.shaderStage))
334     {
335         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
336     }
337     else if (isAllMeshShadingStages(caseDef.shaderStage))
338     {
339         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
340         context.requireDeviceFunctionality("VK_EXT_mesh_shader");
341 
342         if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
343         {
344             const auto &features = context.getMeshShaderFeaturesEXT();
345             if (!features.taskShader)
346                 TCU_THROW(NotSupportedError, "Task shaders not supported");
347         }
348     }
349 #endif // CTS_USES_VULKANSC
350 
351     subgroups::supportedCheckShader(context, caseDef.shaderStage);
352 }
353 
noSSBOtest(Context & context,const CaseDefinition caseDef)354 TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
355 {
356     const subgroups::SSBOData inputData = {
357         subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
358         subgroups::SSBOData::LayoutStd140,      //  InputDataLayoutType layout;
359         caseDef.format,                         //  vk::VkFormat format;
360         subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
361         subgroups::SSBOData::BindingUBO,        //  BindingType bindingType;
362     };
363 
364     switch (caseDef.shaderStage)
365     {
366     case VK_SHADER_STAGE_VERTEX_BIT:
367         return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
368                                                     checkVertexPipelineStages);
369     case VK_SHADER_STAGE_GEOMETRY_BIT:
370         return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
371                                                       checkVertexPipelineStages);
372     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
373         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
374                                                                     checkVertexPipelineStages, caseDef.shaderStage);
375     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
376         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
377                                                                     checkVertexPipelineStages, caseDef.shaderStage);
378     default:
379         TCU_THROW(InternalError, "Unhandled shader stage");
380     }
381 }
382 
test(Context & context,const CaseDefinition caseDef)383 TestStatus test(Context &context, const CaseDefinition caseDef)
384 {
385     const bool isCompute = isAllComputeStages(caseDef.shaderStage);
386 #ifndef CTS_USES_VULKANSC
387     const bool isMesh = isAllMeshShadingStages(caseDef.shaderStage);
388 #else
389     const bool isMesh = false;
390 #endif // CTS_USES_VULKANSC
391     DE_ASSERT(!(isCompute && isMesh));
392 
393     if (isCompute || isMesh)
394     {
395 #ifndef CTS_USES_VULKANSC
396         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
397             context.getSubgroupSizeControlProperties();
398 #else
399         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
400             context.getSubgroupSizeControlPropertiesEXT();
401 #endif // CTS_USES_VULKANSC
402         TestLog &log                        = context.getTestContext().getLog();
403         const subgroups::SSBOData inputData = {
404             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
405             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
406             caseDef.format,                         //  vk::VkFormat format;
407             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
408         };
409 
410         if (caseDef.requiredSubgroupSize == false)
411         {
412             if (isMesh)
413                 return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
414             else
415                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
416                                                   checkComputeOrMesh);
417         }
418 
419         log << TestLog::Message << "Testing required subgroup size range ["
420             << subgroupSizeControlProperties.minSubgroupSize << ", " << subgroupSizeControlProperties.maxSubgroupSize
421             << "]" << TestLog::EndMessage;
422 
423         // According to the spec, requiredSubgroupSize must be a power-of-two integer.
424         for (uint32_t size = subgroupSizeControlProperties.minSubgroupSize;
425              size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
426         {
427             TestStatus result(QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
428 
429             if (isCompute)
430                 result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
431                                                     checkComputeOrMesh, size);
432             else
433                 result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
434                                                  checkComputeOrMesh, size);
435 
436             if (result.getCode() != QP_TEST_RESULT_PASS)
437             {
438                 log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
439                 return result;
440             }
441         }
442 
443         return TestStatus::pass("OK");
444     }
445     else if (isAllGraphicsStages(caseDef.shaderStage))
446     {
447         const VkShaderStageFlags stages = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
448         const subgroups::SSBOData inputData = {
449             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
450             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
451             caseDef.format,                         //  vk::VkFormat format;
452             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
453             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
454             4u,                                     //  uint32_t binding;
455             stages,                                 //  vk::VkShaderStageFlags stages;
456         };
457 
458         return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages,
459                                     stages);
460     }
461 #ifndef CTS_USES_VULKANSC
462     else if (isAllRayTracingStages(caseDef.shaderStage))
463     {
464         const VkShaderStageFlags stages = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
465         const subgroups::SSBOData inputData = {
466             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
467             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
468             caseDef.format,                         //  vk::VkFormat format;
469             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
470             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
471             6u,                                     //  uint32_t binding;
472             stages,                                 //  vk::VkShaderStageFlags stages;
473         };
474 
475         return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
476                                               checkVertexPipelineStages, stages);
477     }
478 #endif // CTS_USES_VULKANSC
479     else
480         TCU_THROW(InternalError, "Unknown stage or invalid stage set");
481 }
482 } // namespace
483 
484 namespace vkt
485 {
486 namespace subgroups
487 {
createSubgroupsArithmeticTests(TestContext & testCtx)488 TestCaseGroup *createSubgroupsArithmeticTests(TestContext &testCtx)
489 {
490     de::MovePtr<TestCaseGroup> group(new TestCaseGroup(testCtx, "arithmetic"));
491 
492     de::MovePtr<TestCaseGroup> graphicGroup(new TestCaseGroup(testCtx, "graphics"));
493     de::MovePtr<TestCaseGroup> computeGroup(new TestCaseGroup(testCtx, "compute"));
494     de::MovePtr<TestCaseGroup> framebufferGroup(new TestCaseGroup(testCtx, "framebuffer"));
495 #ifndef CTS_USES_VULKANSC
496     de::MovePtr<TestCaseGroup> raytracingGroup(new TestCaseGroup(testCtx, "ray_tracing"));
497     de::MovePtr<TestCaseGroup> meshGroup(new TestCaseGroup(testCtx, "mesh"));
498 #endif // CTS_USES_VULKANSC
499 
500     const VkShaderStageFlags fbStages[] = {
501         VK_SHADER_STAGE_VERTEX_BIT,
502         VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
503         VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
504         VK_SHADER_STAGE_GEOMETRY_BIT,
505     };
506 #ifndef CTS_USES_VULKANSC
507     const VkShaderStageFlags meshStages[] = {
508         VK_SHADER_STAGE_MESH_BIT_EXT,
509         VK_SHADER_STAGE_TASK_BIT_EXT,
510     };
511 #endif // CTS_USES_VULKANSC
512     const bool boolValues[] = {false, true};
513 
514     {
515         const vector<VkFormat> formats = subgroups::getAllFormats();
516 
517         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
518         {
519             const VkFormat format           = formats[formatIndex];
520             const string formatName         = subgroups::getFormatNameForGLSL(format);
521             const bool isBool               = subgroups::isFormatBool(format);
522             const bool isFloat              = subgroups::isFormatFloat(format);
523             const bool needs8BitUBOStorage  = isFormat8bitTy(format);
524             const bool needs16BitUBOStorage = isFormat16BitTy(format);
525 
526             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
527             {
528                 const OpType opType    = static_cast<OpType>(opTypeIndex);
529                 const Operator op      = getOperator(opType);
530                 const ScanType st      = getScanType(opType);
531                 const bool isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
532 
533                 // Skip float with bitwise category.
534                 if (isFloat && isBitwiseOp)
535                     continue;
536 
537                 // Skip bool when its not the bitwise category.
538                 if (isBool && !isBitwiseOp)
539                     continue;
540 
541                 const string name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
542 
543                 for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
544                 {
545                     const bool requiredSubgroupSize = boolValues[groupSizeNdx];
546                     const string testName           = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
547                     const CaseDefinition caseDef    = {
548                         op,                            //  Operator op;
549                         st,                            //  ScanType scanType;
550                         VK_SHADER_STAGE_COMPUTE_BIT,   //  VkShaderStageFlags shaderStage;
551                         format,                        //  VkFormat format;
552                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
553                         requiredSubgroupSize,          //  bool requiredSubgroupSize;
554                         false,                         //  bool requires8BitUniformBuffer;
555                         false                          //  bool requires16BitUniformBuffer;
556                     };
557 
558                     addFunctionCaseWithPrograms(computeGroup.get(), testName, supportedCheck, initPrograms, test,
559                                                 caseDef);
560                 }
561 
562 #ifndef CTS_USES_VULKANSC
563                 for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
564                 {
565                     for (const auto &meshStage : meshStages)
566                     {
567                         const bool requiredSubgroupSize = boolValues[groupSizeNdx];
568                         const string testName           = name + "_" + getShaderStageName(meshStage) +
569                                                 (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
570                         const CaseDefinition caseDef = {
571                             op,                            //  Operator op;
572                             st,                            //  ScanType scanType;
573                             meshStage,                     //  VkShaderStageFlags shaderStage;
574                             format,                        //  VkFormat format;
575                             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
576                             requiredSubgroupSize,          //  bool requiredSubgroupSize;
577                             false,                         //  bool requires8BitUniformBuffer;
578                             false                          //  bool requires16BitUniformBuffer;
579                         };
580 
581                         addFunctionCaseWithPrograms(meshGroup.get(), testName, supportedCheck, initPrograms, test,
582                                                     caseDef);
583                     }
584                 }
585 #endif // CTS_USES_VULKANSC
586 
587                 {
588                     const CaseDefinition caseDef = {
589                         op,                            //  Operator op;
590                         st,                            //  ScanType scanType;
591                         VK_SHADER_STAGE_ALL_GRAPHICS,  //  VkShaderStageFlags shaderStage;
592                         format,                        //  VkFormat format;
593                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
594                         false,                         //  bool requiredSubgroupSize;
595                         false,                         //  bool requires8BitUniformBuffer;
596                         false                          //  bool requires16BitUniformBuffer;
597                     };
598 
599                     addFunctionCaseWithPrograms(graphicGroup.get(), name, supportedCheck, initPrograms, test, caseDef);
600                 }
601 
602                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
603                 {
604                     const CaseDefinition caseDef = {
605                         op,                            //  Operator op;
606                         st,                            //  ScanType scanType;
607                         fbStages[stageIndex],          //  VkShaderStageFlags shaderStage;
608                         format,                        //  VkFormat format;
609                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
610                         false,                         //  bool requiredSubgroupSize;
611                         bool(needs8BitUBOStorage),     //  bool requires8BitUniformBuffer;
612                         bool(needs16BitUBOStorage)     //  bool requires16BitUniformBuffer;
613                     };
614                     const string testName = name + "_" + getShaderStageName(caseDef.shaderStage);
615 
616                     addFunctionCaseWithPrograms(framebufferGroup.get(), testName, supportedCheck,
617                                                 initFrameBufferPrograms, noSSBOtest, caseDef);
618                 }
619             }
620         }
621     }
622 
623 #ifndef CTS_USES_VULKANSC
624     {
625         const vector<VkFormat> formats = subgroups::getAllRayTracingFormats();
626 
627         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
628         {
629             const VkFormat format   = formats[formatIndex];
630             const string formatName = subgroups::getFormatNameForGLSL(format);
631             const bool isBool       = subgroups::isFormatBool(format);
632             const bool isFloat      = subgroups::isFormatFloat(format);
633 
634             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
635             {
636                 const OpType opType    = static_cast<OpType>(opTypeIndex);
637                 const Operator op      = getOperator(opType);
638                 const ScanType st      = getScanType(opType);
639                 const bool isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
640 
641                 // Skip float with bitwise category.
642                 if (isFloat && isBitwiseOp)
643                     continue;
644 
645                 // Skip bool when its not the bitwise category.
646                 if (isBool && !isBitwiseOp)
647                     continue;
648 
649                 {
650                     const CaseDefinition caseDef = {
651                         op,                            //  Operator op;
652                         st,                            //  ScanType scanType;
653                         SHADER_STAGE_ALL_RAY_TRACING,  //  VkShaderStageFlags shaderStage;
654                         format,                        //  VkFormat format;
655                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
656                         false,                         //  bool requiredSubgroupSize;
657                         false,                         //  bool requires8BitUniformBuffer;
658                         false                          //  bool requires16BitUniformBuffer;
659                     };
660                     const string name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
661 
662                     addFunctionCaseWithPrograms(raytracingGroup.get(), name, supportedCheck, initPrograms, test,
663                                                 caseDef);
664                 }
665             }
666         }
667     }
668 #endif // CTS_USES_VULKANSC
669 
670     group->addChild(graphicGroup.release());
671     group->addChild(computeGroup.release());
672     group->addChild(framebufferGroup.release());
673 #ifndef CTS_USES_VULKANSC
674     group->addChild(raytracingGroup.release());
675     group->addChild(meshGroup.release());
676 #endif // CTS_USES_VULKANSC
677 
678     return group.release();
679 }
680 } // namespace subgroups
681 } // namespace vkt
682