1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  * Copyright (c) 2018 NVIDIA Corporation
9  *
10  * Licensed under the Apache License, Version 2.0 (the "License");
11  * you may not use this file except in compliance with the License.
12  * You may obtain a copy of the License at
13  *
14  *      http://www.apache.org/licenses/LICENSE-2.0
15  *
16  * Unless required by applicable law or agreed to in writing, software
17  * distributed under the License is distributed on an "AS IS" BASIS,
18  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19  * See the License for the specific language governing permissions and
20  * limitations under the License.
21  *
22  */ /*!
23  * \file
24  * \brief Subgroups Tests
25  */ /*--------------------------------------------------------------------*/
26 
27 #include "vktSubgroupsPartitionedTests.hpp"
28 #include "vktSubgroupsScanHelpers.hpp"
29 #include "vktSubgroupsTestsUtils.hpp"
30 
31 #include <string>
32 #include <vector>
33 
34 using namespace tcu;
35 using namespace std;
36 using namespace vk;
37 using namespace vkt;
38 
39 namespace
40 {
// Every subgroup operation variant covered by this file.
//
// The enumerators form three runs of the same seven operators: plain
// reductions first, then inclusive scans, then exclusive scans. getOperator()
// and getScanType() split a value into its operator and scan-type parts.
enum OpType
{
    // Reductions (getScanType() maps these to SCAN_REDUCE).
    OPTYPE_ADD = 0,
    OPTYPE_MUL,
    OPTYPE_MIN,
    OPTYPE_MAX,
    OPTYPE_AND,
    OPTYPE_OR,
    OPTYPE_XOR,
    // Inclusive scans (getScanType() maps these to SCAN_INCLUSIVE).
    OPTYPE_INCLUSIVE_ADD,
    OPTYPE_INCLUSIVE_MUL,
    OPTYPE_INCLUSIVE_MIN,
    OPTYPE_INCLUSIVE_MAX,
    OPTYPE_INCLUSIVE_AND,
    OPTYPE_INCLUSIVE_OR,
    OPTYPE_INCLUSIVE_XOR,
    // Exclusive scans (getScanType() maps these to SCAN_EXCLUSIVE).
    OPTYPE_EXCLUSIVE_ADD,
    OPTYPE_EXCLUSIVE_MUL,
    OPTYPE_EXCLUSIVE_MIN,
    OPTYPE_EXCLUSIVE_MAX,
    OPTYPE_EXCLUSIVE_AND,
    OPTYPE_EXCLUSIVE_OR,
    OPTYPE_EXCLUSIVE_XOR,
    // Number of operation types; used as the loop bound when generating cases.
    OPTYPE_LAST
};
66 
// Parameters of a single generated test case.
//
// Instances are aggregate-initialised positionally when the cases are created,
// so the member order below must not change.
struct CaseDefinition
{
    // Arithmetic or bitwise operator under test.
    Operator op;
    // Whether the operator is used as a reduction, inclusive scan or exclusive scan.
    ScanType scanType;
    // Stage, or set of stages, whose shaders run the test body.
    VkShaderStageFlags shaderStage;
    // Format of the input data the operation is applied to.
    VkFormat format;
    // Shared flag written by supportedCheck() and read when the programs are built.
    de::SharedPtr<bool> geometryPointSizeSupported;
    // When true, supportedCheck() requires subgroup size control and test()
    // repeats the run for each required subgroup size in the reported range.
    bool requiredSubgroupSize;
    // When true, supportedCheck() requires 8-bit UBO storage support.
    bool requires8BitUniformBuffer;
    // When true, supportedCheck() requires 16-bit UBO storage support.
    bool requires16BitUniformBuffer;
};
78 
getOperator(OpType opType)79 static Operator getOperator(OpType opType)
80 {
81     switch (opType)
82     {
83     case OPTYPE_ADD:
84     case OPTYPE_INCLUSIVE_ADD:
85     case OPTYPE_EXCLUSIVE_ADD:
86         return OPERATOR_ADD;
87     case OPTYPE_MUL:
88     case OPTYPE_INCLUSIVE_MUL:
89     case OPTYPE_EXCLUSIVE_MUL:
90         return OPERATOR_MUL;
91     case OPTYPE_MIN:
92     case OPTYPE_INCLUSIVE_MIN:
93     case OPTYPE_EXCLUSIVE_MIN:
94         return OPERATOR_MIN;
95     case OPTYPE_MAX:
96     case OPTYPE_INCLUSIVE_MAX:
97     case OPTYPE_EXCLUSIVE_MAX:
98         return OPERATOR_MAX;
99     case OPTYPE_AND:
100     case OPTYPE_INCLUSIVE_AND:
101     case OPTYPE_EXCLUSIVE_AND:
102         return OPERATOR_AND;
103     case OPTYPE_OR:
104     case OPTYPE_INCLUSIVE_OR:
105     case OPTYPE_EXCLUSIVE_OR:
106         return OPERATOR_OR;
107     case OPTYPE_XOR:
108     case OPTYPE_INCLUSIVE_XOR:
109     case OPTYPE_EXCLUSIVE_XOR:
110         return OPERATOR_XOR;
111     default:
112         DE_FATAL("Unsupported op type");
113         return OPERATOR_ADD;
114     }
115 }
116 
getScanType(OpType opType)117 static ScanType getScanType(OpType opType)
118 {
119     switch (opType)
120     {
121     case OPTYPE_ADD:
122     case OPTYPE_MUL:
123     case OPTYPE_MIN:
124     case OPTYPE_MAX:
125     case OPTYPE_AND:
126     case OPTYPE_OR:
127     case OPTYPE_XOR:
128         return SCAN_REDUCE;
129     case OPTYPE_INCLUSIVE_ADD:
130     case OPTYPE_INCLUSIVE_MUL:
131     case OPTYPE_INCLUSIVE_MIN:
132     case OPTYPE_INCLUSIVE_MAX:
133     case OPTYPE_INCLUSIVE_AND:
134     case OPTYPE_INCLUSIVE_OR:
135     case OPTYPE_INCLUSIVE_XOR:
136         return SCAN_INCLUSIVE;
137     case OPTYPE_EXCLUSIVE_ADD:
138     case OPTYPE_EXCLUSIVE_MUL:
139     case OPTYPE_EXCLUSIVE_MIN:
140     case OPTYPE_EXCLUSIVE_MAX:
141     case OPTYPE_EXCLUSIVE_AND:
142     case OPTYPE_EXCLUSIVE_OR:
143     case OPTYPE_EXCLUSIVE_XOR:
144         return SCAN_EXCLUSIVE;
145     default:
146         DE_FATAL("Unsupported op type");
147         return SCAN_REDUCE;
148     }
149 }
150 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,uint32_t width,uint32_t)151 static bool checkVertexPipelineStages(const void *internalData, vector<const void *> datas, uint32_t width, uint32_t)
152 {
153     DE_UNREF(internalData);
154 
155     return subgroups::check(datas, width, 0xFFFFFF);
156 }
157 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)158 static bool checkComputeOrMesh(const void *internalData, vector<const void *> datas, const uint32_t numWorkgroups[3],
159                                const uint32_t localSize[3], uint32_t)
160 {
161     DE_UNREF(internalData);
162 
163     return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 0xFFFFFF);
164 }
165 
getOpTypeName(Operator op,ScanType scanType)166 string getOpTypeName(Operator op, ScanType scanType)
167 {
168     return getScanOpName("subgroup", "", op, scanType);
169 }
170 
getOpTypeNamePartitioned(Operator op,ScanType scanType)171 string getOpTypeNamePartitioned(Operator op, ScanType scanType)
172 {
173     return getScanOpName("subgroupPartitioned", "NV", op, scanType);
174 }
175 
getExtHeader(const CaseDefinition & caseDef)176 string getExtHeader(const CaseDefinition &caseDef)
177 {
178     return "#extension GL_NV_shader_subgroup_partitioned: enable\n"
179            "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
180            "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
181            subgroups::getAdditionalExtensionForFormat(caseDef.format);
182 }
183 
getTestString(const CaseDefinition & caseDef)184 string getTestString(const CaseDefinition &caseDef)
185 {
186     Operator op = caseDef.op;
187     ScanType st = caseDef.scanType;
188 
189     // NOTE: tempResult can't have anything in bits 31:24 to avoid int->float
190     // conversion overflow in framebuffer tests.
191     string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
192     string bdy = "  uvec4 mask = subgroupBallot(true);\n"
193                  "  uint tempResult = 0;\n"
194                  "  uint id = gl_SubgroupInvocationID;\n";
195 
196     // Test the case where the partition has a single subset with all invocations in it.
197     // This should generate the same result as the non-partitioned function.
198     bdy += "  uvec4 allBallot = mask;\n"
199            "  " +
200            fmt + " allResult = " + getOpTypeNamePartitioned(op, st) +
201            "(data[gl_SubgroupInvocationID], allBallot);\n"
202            "  " +
203            fmt + " refResult = " + getOpTypeName(op, st) +
204            "(data[gl_SubgroupInvocationID]);\n"
205            "  if (" +
206            getCompare(op, caseDef.format, "allResult", "refResult") +
207            ") {\n"
208            "      tempResult |= 0x1;\n"
209            "  }\n";
210 
211     // The definition of a partition doesn't forbid bits corresponding to inactive
212     // invocations being in the subset with active invocations. In other words, test that
213     // bits corresponding to inactive invocations are ignored.
214     bdy += "  if (0 == (gl_SubgroupInvocationID % 2)) {\n"
215            "    " +
216            fmt + " allResult = " + getOpTypeNamePartitioned(op, st) +
217            "(data[gl_SubgroupInvocationID], allBallot);\n"
218            "    " +
219            fmt + " refResult = " + getOpTypeName(op, st) +
220            "(data[gl_SubgroupInvocationID]);\n"
221            "    if (" +
222            getCompare(op, caseDef.format, "allResult", "refResult") +
223            ") {\n"
224            "        tempResult |= 0x2;\n"
225            "    }\n"
226            "  } else {\n"
227            "    tempResult |= 0x2;\n"
228            "  }\n";
229 
230     // Test the case where the partition has each invocation in a unique subset. For
231     // exclusive ops, the result is identity. For reduce/inclusive, it's the original value.
232     string expectedSelfResult = "data[gl_SubgroupInvocationID]";
233     if (st == SCAN_EXCLUSIVE)
234         expectedSelfResult = getIdentity(op, caseDef.format);
235 
236     bdy += "  uvec4 selfBallot = subgroupPartitionNV(gl_SubgroupInvocationID);\n"
237            "  " +
238            fmt + " selfResult = " + getOpTypeNamePartitioned(op, st) +
239            "(data[gl_SubgroupInvocationID], selfBallot);\n"
240            "  if (" +
241            getCompare(op, caseDef.format, "selfResult", expectedSelfResult) +
242            ") {\n"
243            "      tempResult |= 0x4;\n"
244            "  }\n";
245 
246     // Test "random" partitions based on a hash of the invocation id.
247     // This "hash" function produces interesting/randomish partitions.
248     static const char *idhash = "((id%N)+(id%(N+1))-(id%2)+(id/2))%((N+1)/2)";
249 
250     bdy += "  for (uint N = 1; N < 16; ++N) {\n"
251            "    " +
252            fmt + " idhashFmt = " + fmt + "(" + idhash +
253            ");\n"
254            "    uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
255            "    " +
256            fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) +
257            "(data[gl_SubgroupInvocationID], partitionBallot);\n"
258            "      for (uint i = 0; i < N; ++i) {\n"
259            "        " +
260            fmt + " iFmt = " + fmt +
261            "(i);\n"
262            "        if (" +
263            getCompare(op, caseDef.format, "idhashFmt", "iFmt") +
264            ") {\n"
265            "          " +
266            fmt + " subsetResult = " + getOpTypeName(op, st) +
267            "(data[gl_SubgroupInvocationID]);\n"
268            "          tempResult |= " +
269            getCompare(op, caseDef.format, "partitionedResult", "subsetResult") +
270            " ? (0x4 << N) : 0;\n"
271            "        }\n"
272            "      }\n"
273            "  }\n"
274            // tests in flow control:
275            "  if (1 == (gl_SubgroupInvocationID % 2)) {\n"
276            "    for (uint N = 1; N < 7; ++N) {\n"
277            "      " +
278            fmt + " idhashFmt = " + fmt + "(" + idhash +
279            ");\n"
280            "      uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
281            "      " +
282            fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) +
283            "(data[gl_SubgroupInvocationID], partitionBallot);\n"
284            "        for (uint i = 0; i < N; ++i) {\n"
285            "          " +
286            fmt + " iFmt = " + fmt +
287            "(i);\n"
288            "          if (" +
289            getCompare(op, caseDef.format, "idhashFmt", "iFmt") +
290            ") {\n"
291            "            " +
292            fmt + " subsetResult = " + getOpTypeName(op, st) +
293            "(data[gl_SubgroupInvocationID]);\n"
294            "            tempResult |= " +
295            getCompare(op, caseDef.format, "partitionedResult", "subsetResult") +
296            " ? (0x20000 << N) : 0;\n"
297            "          }\n"
298            "        }\n"
299            "    }\n"
300            "  } else {\n"
301            "    tempResult |= 0xFC0000;\n"
302            "  }\n"
303            "  tempRes = tempResult;\n";
304 
305     return bdy;
306 }
307 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)308 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
309 {
310     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
311     const string extHeader      = getExtHeader(caseDef);
312     const string testSrc        = getTestString(caseDef);
313     const bool pointSizeSupport = *caseDef.geometryPointSizeSupported;
314 
315     subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format,
316                                           pointSizeSupport, extHeader, testSrc, "");
317 }
318 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)319 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
320 {
321     const bool spirv14required =
322         (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
323     const SpirvVersion spirvVersion = spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
324     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
325     const string extHeader      = getExtHeader(caseDef);
326     const string testSrc        = getTestString(caseDef);
327     const bool pointSizeSupport = *caseDef.geometryPointSizeSupported;
328 
329     subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSizeSupport,
330                                extHeader, testSrc, "");
331 }
332 
supportedCheck(Context & context,CaseDefinition caseDef)333 void supportedCheck(Context &context, CaseDefinition caseDef)
334 {
335     if (!subgroups::isSubgroupSupported(context))
336         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
337 
338     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_PARTITIONED_BIT_NV))
339         TCU_THROW(NotSupportedError, "Device does not support subgroup partitioned operations");
340 
341     if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
342         TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
343 
344     if (caseDef.requires16BitUniformBuffer)
345     {
346         if (!subgroups::is16BitUBOStorageSupported(context))
347         {
348             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
349         }
350     }
351 
352     if (caseDef.requires8BitUniformBuffer)
353     {
354         if (!subgroups::is8BitUBOStorageSupported(context))
355         {
356             TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
357         }
358     }
359 
360     if (caseDef.requiredSubgroupSize)
361     {
362         context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
363 
364         const VkPhysicalDeviceSubgroupSizeControlFeatures &subgroupSizeControlFeatures =
365             context.getSubgroupSizeControlFeatures();
366         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
367             context.getSubgroupSizeControlProperties();
368 
369         if (subgroupSizeControlFeatures.subgroupSizeControl == false)
370             TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
371 
372         if (subgroupSizeControlFeatures.computeFullSubgroups == false)
373             TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
374 
375         if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
376             TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
377     }
378 
379     *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
380 
381     if (isAllRayTracingStages(caseDef.shaderStage))
382     {
383         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
384     }
385     else if (isAllMeshShadingStages(caseDef.shaderStage))
386     {
387         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
388         context.requireDeviceFunctionality("VK_EXT_mesh_shader");
389 
390         if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
391         {
392             const auto &features = context.getMeshShaderFeaturesEXT();
393             if (!features.taskShader)
394                 TCU_THROW(NotSupportedError, "Task shaders not supported");
395         }
396     }
397 
398     subgroups::supportedCheckShader(context, caseDef.shaderStage);
399 }
400 
noSSBOtest(Context & context,const CaseDefinition caseDef)401 TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
402 {
403     const subgroups::SSBOData inputData{
404         subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
405         subgroups::SSBOData::LayoutStd140,      //  InputDataLayoutType layout;
406         caseDef.format,                         //  vk::VkFormat format;
407         subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
408         subgroups::SSBOData::BindingUBO,        //  BindingType bindingType;
409     };
410 
411     switch (caseDef.shaderStage)
412     {
413     case VK_SHADER_STAGE_VERTEX_BIT:
414         return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
415                                                     checkVertexPipelineStages);
416     case VK_SHADER_STAGE_GEOMETRY_BIT:
417         return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
418                                                       checkVertexPipelineStages);
419     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
420         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
421                                                                     checkVertexPipelineStages, caseDef.shaderStage);
422     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
423         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
424                                                                     checkVertexPipelineStages, caseDef.shaderStage);
425     default:
426         TCU_THROW(InternalError, "Unhandled shader stage");
427     }
428 }
429 
test(Context & context,const CaseDefinition caseDef)430 TestStatus test(Context &context, const CaseDefinition caseDef)
431 {
432     const bool isCompute = isAllComputeStages(caseDef.shaderStage);
433     const bool isMesh    = isAllMeshShadingStages(caseDef.shaderStage);
434     DE_ASSERT(!(isCompute && isMesh));
435 
436     if (isCompute || isMesh)
437     {
438         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
439             context.getSubgroupSizeControlProperties();
440         TestLog &log                        = context.getTestContext().getLog();
441         const subgroups::SSBOData inputData = {
442             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
443             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
444             caseDef.format,                         //  vk::VkFormat format;
445             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
446         };
447 
448         if (caseDef.requiredSubgroupSize == false)
449         {
450             if (isCompute)
451                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
452                                                   checkComputeOrMesh);
453             else
454                 return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
455         }
456 
457         log << TestLog::Message << "Testing required subgroup size range ["
458             << subgroupSizeControlProperties.minSubgroupSize << ", " << subgroupSizeControlProperties.maxSubgroupSize
459             << "]" << TestLog::EndMessage;
460 
461         // According to the spec, requiredSubgroupSize must be a power-of-two integer.
462         for (uint32_t size = subgroupSizeControlProperties.minSubgroupSize;
463              size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
464         {
465             TestStatus result(QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
466 
467             if (isCompute)
468                 result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
469                                                     checkComputeOrMesh, size);
470             else
471                 result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
472                                                  checkComputeOrMesh, size);
473 
474             if (result.getCode() != QP_TEST_RESULT_PASS)
475             {
476                 log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
477                 return result;
478             }
479         }
480 
481         return TestStatus::pass("OK");
482     }
483     else if (isAllGraphicsStages(caseDef.shaderStage))
484     {
485         const VkShaderStageFlags stages = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
486         const subgroups::SSBOData inputData{
487             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
488             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
489             caseDef.format,                         //  vk::VkFormat format;
490             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
491             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
492             4u,                                     //  uint32_t binding;
493             stages,                                 //  vk::VkShaderStageFlags stages;
494         };
495 
496         return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages,
497                                     stages);
498     }
499     else if (isAllRayTracingStages(caseDef.shaderStage))
500     {
501         const VkShaderStageFlags stages = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
502         const subgroups::SSBOData inputData{
503             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
504             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
505             caseDef.format,                         //  vk::VkFormat format;
506             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
507             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
508             6u,                                     //  uint32_t binding;
509             stages,                                 //  vk::VkShaderStageFlags stages;
510         };
511 
512         return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
513                                               checkVertexPipelineStages, stages);
514     }
515     else
516         TCU_THROW(InternalError, "Unknown stage or invalid stage set");
517 }
518 } // namespace
519 
520 namespace vkt
521 {
522 namespace subgroups
523 {
createSubgroupsPartitionedTests(TestContext & testCtx)524 TestCaseGroup *createSubgroupsPartitionedTests(TestContext &testCtx)
525 {
526     de::MovePtr<TestCaseGroup> group(new TestCaseGroup(testCtx, "partitioned"));
527     de::MovePtr<TestCaseGroup> graphicGroup(new TestCaseGroup(testCtx, "graphics"));
528     de::MovePtr<TestCaseGroup> computeGroup(new TestCaseGroup(testCtx, "compute"));
529     de::MovePtr<TestCaseGroup> meshGroup(new TestCaseGroup(testCtx, "mesh"));
530     de::MovePtr<TestCaseGroup> framebufferGroup(new TestCaseGroup(testCtx, "framebuffer"));
531     de::MovePtr<TestCaseGroup> raytracingGroup(new TestCaseGroup(testCtx, "ray_tracing"));
532     const VkShaderStageFlags fbStages[] = {
533         VK_SHADER_STAGE_VERTEX_BIT,
534         VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
535         VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
536         VK_SHADER_STAGE_GEOMETRY_BIT,
537     };
538     const VkShaderStageFlags meshStages[] = {
539         VK_SHADER_STAGE_MESH_BIT_EXT,
540         VK_SHADER_STAGE_TASK_BIT_EXT,
541     };
542     const bool boolValues[] = {false, true};
543 
544     {
545         const vector<VkFormat> formats = subgroups::getAllFormats();
546 
547         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
548         {
549             const VkFormat format           = formats[formatIndex];
550             const string formatName         = subgroups::getFormatNameForGLSL(format);
551             const bool isBool               = subgroups::isFormatBool(format);
552             const bool isFloat              = subgroups::isFormatFloat(format);
553             const bool needs8BitUBOStorage  = isFormat8bitTy(format);
554             const bool needs16BitUBOStorage = isFormat16BitTy(format);
555 
556             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
557             {
558                 const OpType opType    = static_cast<OpType>(opTypeIndex);
559                 const Operator op      = getOperator(opType);
560                 const ScanType st      = getScanType(opType);
561                 const bool isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
562 
563                 // Skip float with bitwise category.
564                 if (isFloat && isBitwiseOp)
565                     continue;
566 
567                 // Skip bool when its not the bitwise category.
568                 if (isBool && !isBitwiseOp)
569                     continue;
570 
571                 const string name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
572 
573                 for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
574                 {
575                     const bool requiredSubgroupSize = boolValues[groupSizeNdx];
576                     const string testName           = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
577                     const CaseDefinition caseDef    = {
578                         op,                            //  Operator op;
579                         st,                            //  ScanType scanType;
580                         VK_SHADER_STAGE_COMPUTE_BIT,   //  VkShaderStageFlags shaderStage;
581                         format,                        //  VkFormat format;
582                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
583                         requiredSubgroupSize,          //  bool requiredSubgroupSize;
584                         false,                         //  bool requires8BitUniformBuffer;
585                         false,                         //  bool requires16BitUniformBuffer;
586                     };
587 
588                     addFunctionCaseWithPrograms(computeGroup.get(), testName, supportedCheck, initPrograms, test,
589                                                 caseDef);
590                 }
591 
592                 for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
593                 {
594                     for (const auto &stage : meshStages)
595                     {
596                         const bool requiredSubgroupSize = boolValues[groupSizeNdx];
597                         const string testName = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "") + "_" +
598                                                 getShaderStageName(stage);
599                         const CaseDefinition caseDef = {
600                             op,                            //  Operator op;
601                             st,                            //  ScanType scanType;
602                             stage,                         //  VkShaderStageFlags shaderStage;
603                             format,                        //  VkFormat format;
604                             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
605                             requiredSubgroupSize,          //  bool requiredSubgroupSize;
606                             false,                         //  bool requires8BitUniformBuffer;
607                             false,                         //  bool requires16BitUniformBuffer;
608                         };
609 
610                         addFunctionCaseWithPrograms(meshGroup.get(), testName, supportedCheck, initPrograms, test,
611                                                     caseDef);
612                     }
613                 }
614 
615                 {
616                     const CaseDefinition caseDef = {
617                         op,                            //  Operator op;
618                         st,                            //  ScanType scanType;
619                         VK_SHADER_STAGE_ALL_GRAPHICS,  //  VkShaderStageFlags shaderStage;
620                         format,                        //  VkFormat format;
621                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
622                         false,                         //  bool requiredSubgroupSize;
623                         false,                         //  bool requires8BitUniformBuffer;
624                         false                          //  bool requires16BitUniformBuffer;
625                     };
626 
627                     addFunctionCaseWithPrograms(graphicGroup.get(), name, supportedCheck, initPrograms, test, caseDef);
628                 }
629 
630                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
631                 {
632                     const CaseDefinition caseDef = {
633                         op,                            //  Operator op;
634                         st,                            //  ScanType scanType;
635                         fbStages[stageIndex],          //  VkShaderStageFlags shaderStage;
636                         format,                        //  VkFormat format;
637                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
638                         false,                         //  bool requiredSubgroupSize;
639                         bool(needs8BitUBOStorage),     //  bool requires8BitUniformBuffer;
640                         bool(needs16BitUBOStorage)     //  bool requires16BitUniformBuffer;
641                     };
642                     const string testName = name + "_" + getShaderStageName(caseDef.shaderStage);
643 
644                     addFunctionCaseWithPrograms(framebufferGroup.get(), testName, supportedCheck,
645                                                 initFrameBufferPrograms, noSSBOtest, caseDef);
646                 }
647             }
648         }
649     }
650 
651     {
652         const vector<VkFormat> formats = subgroups::getAllRayTracingFormats();
653 
654         for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
655         {
656             const VkFormat format   = formats[formatIndex];
657             const string formatName = subgroups::getFormatNameForGLSL(format);
658             const bool isBool       = subgroups::isFormatBool(format);
659             const bool isFloat      = subgroups::isFormatFloat(format);
660 
661             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
662             {
663                 const OpType opType    = static_cast<OpType>(opTypeIndex);
664                 const Operator op      = getOperator(opType);
665                 const ScanType st      = getScanType(opType);
666                 const bool isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
667 
668                 // Skip float with bitwise category.
669                 if (isFloat && isBitwiseOp)
670                     continue;
671 
672                 // Skip bool when its not the bitwise category.
673                 if (isBool && !isBitwiseOp)
674                     continue;
675 
676                 {
677                     const CaseDefinition caseDef = {
678                         op,                            //  Operator op;
679                         st,                            //  ScanType scanType;
680                         SHADER_STAGE_ALL_RAY_TRACING,  //  VkShaderStageFlags shaderStage;
681                         format,                        //  VkFormat format;
682                         de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
683                         false,                         //  bool requiredSubgroupSize;
684                         false,                         //  bool requires8BitUniformBuffer;
685                         false                          //  bool requires16BitUniformBuffer;
686                     };
687                     const string name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
688 
689                     addFunctionCaseWithPrograms(raytracingGroup.get(), name, supportedCheck, initPrograms, test,
690                                                 caseDef);
691                 }
692             }
693         }
694     }
695 
696     group->addChild(graphicGroup.release());
697     group->addChild(computeGroup.release());
698     group->addChild(framebufferGroup.release());
699     group->addChild(raytracingGroup.release());
700     group->addChild(meshGroup.release());
701 
702     return group.release();
703 }
704 } // namespace subgroups
705 } // namespace vkt
706