1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "glcSubgroupsBallotOtherTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
/*!
 * \brief Ballot built-ins covered by this test file.
 *
 * The values are contiguous and start at zero so that the test-creation code
 * can iterate from 0 up to (but excluding) OPTYPE_LAST.
 */
enum OpType
{
    OPTYPE_INVERSE_BALLOT             = 0, //!< subgroupInverseBallot
    OPTYPE_BALLOT_BIT_EXTRACT         = 1, //!< subgroupBallotBitExtract
    OPTYPE_BALLOT_BIT_COUNT           = 2, //!< subgroupBallotBitCount
    OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT = 3, //!< subgroupBallotInclusiveBitCount
    OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT = 4, //!< subgroupBallotExclusiveBitCount
    OPTYPE_BALLOT_FIND_LSB            = 5, //!< subgroupBallotFindLSB
    OPTYPE_BALLOT_FIND_MSB            = 6, //!< subgroupBallotFindMSB
    OPTYPE_LAST                       = 7  //!< Sentinel: number of op types, not an operation itself.
};
52 
checkVertexPipelineStages(std::vector<const void * > datas,uint32_t width,uint32_t)53 static bool checkVertexPipelineStages(std::vector<const void *> datas, uint32_t width, uint32_t)
54 {
55     return glc::subgroups::check(datas, width, 0xf);
56 }
57 
checkComputeStage(std::vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)58 static bool checkComputeStage(std::vector<const void *> datas, const uint32_t numWorkgroups[3],
59                               const uint32_t localSize[3], uint32_t)
60 {
61     return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 0xf);
62 }
63 
getOpTypeName(int opType)64 std::string getOpTypeName(int opType)
65 {
66     switch (opType)
67     {
68     default:
69         DE_FATAL("Unsupported op type");
70         return "";
71     case OPTYPE_INVERSE_BALLOT:
72         return "subgroupInverseBallot";
73     case OPTYPE_BALLOT_BIT_EXTRACT:
74         return "subgroupBallotBitExtract";
75     case OPTYPE_BALLOT_BIT_COUNT:
76         return "subgroupBallotBitCount";
77     case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
78         return "subgroupBallotInclusiveBitCount";
79     case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
80         return "subgroupBallotExclusiveBitCount";
81     case OPTYPE_BALLOT_FIND_LSB:
82         return "subgroupBallotFindLSB";
83     case OPTYPE_BALLOT_FIND_MSB:
84         return "subgroupBallotFindMSB";
85     }
86 }
87 
88 struct CaseDefinition
89 {
90     int opType;
91     ShaderStageFlags shaderStage;
92 };
93 
getBodySource(CaseDefinition caseDef)94 std::string getBodySource(CaseDefinition caseDef)
95 {
96     std::ostringstream bdy;
97 
98     bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
99         << "  uvec4 allZeros = uvec4(0);\n"
100         << "  uint tempResult = 0u;\n"
101         << "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
102         << "i >= 32u ? 0u : (0xFFFFFFFFu << i), "
103         << "i >= 64u ? 0u : (0xFFFFFFFFu << ((i < 32u) ? 0u : (i - 32u))), "
104         << "i >= 96u ? 0u : (0xFFFFFFFFu << ((i < 64u) ? 0u : (i - 64u))), "
105         << "i == 128u ? 0u : (0xFFFFFFFFu << ((i < 96u) ? 0u : (i - 96u))))\n"
106         << "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
107         << "i >= 32u ? 0u : 0x1u << i, "
108         << "i < 32u || i >= 64u ? 0u : 0x1u << (i - 32u), "
109         << "i < 64u || i >= 96u ? 0u : 0x1u << (i - 64u), "
110         << "i < 96u ? 0u : 0x1u << (i - 96u))\n";
111 
112     switch (caseDef.opType)
113     {
114     default:
115         DE_FATAL("Unknown op type!");
116         break;
117     case OPTYPE_INVERSE_BALLOT:
118         bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1u : 0u;\n"
119             << "  tempResult |= subgroupInverseBallot(allZeros) ? 0u : 0x2u;\n"
120             << "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4u : 0u;\n"
121             << "  tempResult |= 0x8u;\n";
122         break;
123     case OPTYPE_BALLOT_BIT_EXTRACT:
124         bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1u : 0u;\n"
125             << "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0u : 0x2u;\n"
126             << "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4u : 0u;\n"
127             << "  tempResult |= 0x8u;\n"
128             << "  for (uint i = 0u; i < gl_SubgroupSize; i++)\n"
129             << "  {\n"
130             << "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
131             << "    {\n"
132             << "      tempResult &= ~0x8u;\n"
133             << "    }\n"
134             << "  }\n";
135         break;
136     case OPTYPE_BALLOT_BIT_COUNT:
137         bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1u : 0u;\n"
138             << "  tempResult |= 0u == subgroupBallotBitCount(allZeros) ? 0x2u : 0u;\n"
139             << "  tempResult |= 0u < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4u : 0u;\n"
140             << "  tempResult |= 0u == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8u : 0u;\n";
141         break;
142     case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
143         bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1u;\n"
144             << "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1u : 0u;\n"
145             << "  tempResult |= 0u == subgroupBallotInclusiveBitCount(allZeros) ? 0x2u : 0u;\n"
146             << "  tempResult |= 0u < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4u : 0u;\n"
147             << "  tempResult |= 0x8u;\n"
148             << "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
149             << "  bool undefTerritory = false;\n"
150             << "  for (uint i = 0u; i <= 128u; i++)\n"
151             << "  {\n"
152             << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
153             << "    if (iUndef == inclusiveUndef)"
154             << "    {\n"
155             << "      undefTerritory = true;\n"
156             << "    }\n"
157             << "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
158             << "    if (undefTerritory && (0u != inclusiveBitCount))\n"
159             << "    {\n"
160             << "      tempResult &= ~0x8u;\n"
161             << "    }\n"
162             << "    else if (!undefTerritory && (0u == inclusiveBitCount))\n"
163             << "    {\n"
164             << "      tempResult &= ~0x8u;\n"
165             << "    }\n"
166             << "  }\n";
167         break;
168     case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
169         bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
170             << "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1u : 0u;\n"
171             << "  tempResult |= 0u == subgroupBallotExclusiveBitCount(allZeros) ? 0x2u : 0u;\n"
172             << "  tempResult |= 0x4u;\n"
173             << "  tempResult |= 0x8u;\n"
174             << "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
175             << "  bool undefTerritory = false;\n"
176             << "  for (uint i = 0u; i <= 128u; i++)\n"
177             << "  {\n"
178             << "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
179             << "    if (iUndef == exclusiveUndef)"
180             << "    {\n"
181             << "      undefTerritory = true;\n"
182             << "    }\n"
183             << "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
184             << "    if (undefTerritory && (0u != exclusiveBitCount))\n"
185             << "    {\n"
186             << "      tempResult &= ~0x4u;\n"
187             << "    }\n"
188             << "    else if (!undefTerritory && (0u == exclusiveBitCount))\n"
189             << "    {\n"
190             << "      tempResult &= ~0x8u;\n"
191             << "    }\n"
192             << "  }\n";
193         break;
194     case OPTYPE_BALLOT_FIND_LSB:
195         bdy << "  tempResult |= 0u == subgroupBallotFindLSB(allOnes) ? 0x1u : 0u;\n"
196             << "  if (subgroupElect())\n"
197             << "  {\n"
198             << "    tempResult |= 0x2u;\n"
199             << "  }\n"
200             << "  else\n"
201             << "  {\n"
202             << "    tempResult |= 0u < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2u : 0u;\n"
203             << "  }\n"
204             << "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4u : 0u;\n"
205             << "  tempResult |= 0x8u;\n"
206             << "  for (uint i = 0u; i < gl_SubgroupSize; i++)\n"
207             << "  {\n"
208             << "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
209             << "    {\n"
210             << "      tempResult &= ~0x8u;\n"
211             << "    }\n"
212             << "  }\n";
213         break;
214     case OPTYPE_BALLOT_FIND_MSB:
215         bdy << "  tempResult |= (gl_SubgroupSize - 1u) == subgroupBallotFindMSB(allOnes) ? 0x1u : 0u;\n"
216             << "  if (subgroupElect())\n"
217             << "  {\n"
218             << "    tempResult |= 0x2u;\n"
219             << "  }\n"
220             << "  else\n"
221             << "  {\n"
222             << "    tempResult |= 0u < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2u : 0u;\n"
223             << "  }\n"
224             << "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4u : 0u;\n"
225             << "  tempResult |= 0x8u;\n"
226             << "  for (uint i = 0u; i < gl_SubgroupSize; i++)\n"
227             << "  {\n"
228             << "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
229             << "    {\n"
230             << "      tempResult &= ~0x8u;\n"
231             << "    }\n"
232             << "  }\n";
233         break;
234     }
235     return bdy.str();
236 }
237 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)238 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
239 {
240     subgroups::setFragmentShaderFrameBuffer(programCollection);
241 
242     if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
243         subgroups::setVertexShaderFrameBuffer(programCollection);
244 
245     std::string bdyStr = getBodySource(caseDef);
246 
247     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
248     {
249         std::ostringstream vertex;
250         vertex << "${VERSION_DECL}\n"
251                << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
252                << "layout(location = 0) in highp vec4 in_position;\n"
253                << "layout(location = 0) out float out_color;\n"
254                << "\n"
255                << "void main (void)\n"
256                << "{\n"
257                << bdyStr << "  out_color = float(tempResult);\n"
258                << "  gl_Position = in_position;\n"
259                << "  gl_PointSize = 1.0f;\n"
260                << "}\n";
261         programCollection.add("vert") << glu::VertexSource(vertex.str());
262     }
263     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
264     {
265         std::ostringstream geometry;
266 
267         geometry << "${VERSION_DECL}\n"
268                  << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
269                  << "layout(points) in;\n"
270                  << "layout(points, max_vertices = 1) out;\n"
271                  << "layout(location = 0) out float out_color;\n"
272                  << "void main (void)\n"
273                  << "{\n"
274                  << bdyStr << "  out_color = float(tempResult);\n"
275                  << "  gl_Position = gl_in[0].gl_Position;\n"
276                  << "  EmitVertex();\n"
277                  << "  EndPrimitive();\n"
278                  << "}\n";
279 
280         programCollection.add("geometry") << glu::GeometrySource(geometry.str());
281     }
282     else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
283     {
284         std::ostringstream controlSource;
285 
286         controlSource << "${VERSION_DECL}\n"
287                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
288                       << "layout(vertices = 2) out;\n"
289                       << "layout(location = 0) out float out_color[];\n"
290                       << "\n"
291                       << "void main (void)\n"
292                       << "{\n"
293                       << "  if (gl_InvocationID == 0)\n"
294                       << "  {\n"
295                       << "    gl_TessLevelOuter[0] = 1.0f;\n"
296                       << "    gl_TessLevelOuter[1] = 1.0f;\n"
297                       << "  }\n"
298                       << bdyStr << "  out_color[gl_InvocationID ] = float(tempResult);\n"
299                       << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
300                       << "}\n";
301 
302         programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
303         subgroups::setTesEvalShaderFrameBuffer(programCollection);
304     }
305     else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
306     {
307         std::ostringstream evaluationSource;
308         evaluationSource << "${VERSION_DECL}\n"
309                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
310                          << "layout(isolines, equal_spacing, ccw ) in;\n"
311                          << "layout(location = 0) out float out_color;\n"
312                          << "void main (void)\n"
313                          << "{\n"
314                          << bdyStr << "  out_color  = float(tempResult);\n"
315                          << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
316                          << "}\n";
317 
318         subgroups::setTesCtrlShaderFrameBuffer(programCollection);
319         programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
320     }
321     else
322     {
323         DE_FATAL("Unsupported shader stage");
324     }
325 }
326 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)327 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
328 {
329     std::string bdyStr = getBodySource(caseDef);
330 
331     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
332     {
333         std::ostringstream src;
334 
335         src << "${VERSION_DECL}\n"
336             << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
337             << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
338             << "layout(binding = 0, std430) buffer Buffer0\n"
339             << "{\n"
340             << "  uint result[];\n"
341             << "};\n"
342             << "\n"
343             << "void main (void)\n"
344             << "{\n"
345             << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
346             << "  highp uint offset = globalSize.x * ((globalSize.y * "
347                "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
348                "gl_GlobalInvocationID.x;\n"
349             << bdyStr << "  result[offset] = tempResult;\n"
350             << "}\n";
351 
352         programCollection.add("comp") << glu::ComputeSource(src.str());
353     }
354     else
355     {
356         const string vertex =
357             "${VERSION_DECL}\n"
358             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
359             "layout(binding = 0, std430) buffer Buffer0\n"
360             "{\n"
361             "  uint result[];\n"
362             "} b0;\n"
363             "\n"
364             "void main (void)\n"
365             "{\n" +
366             bdyStr +
367             "  b0.result[gl_VertexID] = tempResult;\n"
368             "  float pixelSize = 2.0f/1024.0f;\n"
369             "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
370             "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
371             "  gl_PointSize = 1.0f;\n"
372             "}\n";
373 
374         const string tesc = "${VERSION_DECL}\n"
375                             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
376                             "layout(vertices=1) out;\n"
377                             "layout(binding = 1, std430) buffer Buffer1\n"
378                             "{\n"
379                             "  uint result[];\n"
380                             "} b1;\n"
381                             "\n"
382                             "void main (void)\n"
383                             "{\n" +
384                             bdyStr +
385                             "  b1.result[gl_PrimitiveID] = tempResult;\n"
386                             "  if (gl_InvocationID == 0)\n"
387                             "  {\n"
388                             "    gl_TessLevelOuter[0] = 1.0f;\n"
389                             "    gl_TessLevelOuter[1] = 1.0f;\n"
390                             "  }\n"
391                             "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
392                             "}\n";
393 
394         const string tese = "${VERSION_DECL}\n"
395                             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
396                             "layout(isolines) in;\n"
397                             "layout(binding = 2, std430) buffer Buffer2\n"
398                             "{\n"
399                             "  uint result[];\n"
400                             "} b2;\n"
401                             "\n"
402                             "void main (void)\n"
403                             "{\n" +
404                             bdyStr +
405                             "  b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = tempResult;\n"
406                             "  float pixelSize = 2.0f/1024.0f;\n"
407                             "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
408                             "}\n";
409 
410         const string geometry =
411             // version string added by addGeometryShadersFromTemplate
412             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
413             "layout(${TOPOLOGY}) in;\n"
414             "layout(points, max_vertices = 1) out;\n"
415             "layout(binding = 3, std430) buffer Buffer3\n"
416             "{\n"
417             "  uint result[];\n"
418             "} b3;\n"
419             "\n"
420             "void main (void)\n"
421             "{\n" +
422             bdyStr +
423             "  b3.result[gl_PrimitiveIDIn] = tempResult;\n"
424             "  gl_Position = gl_in[0].gl_Position;\n"
425             "  EmitVertex();\n"
426             "  EndPrimitive();\n"
427             "}\n";
428 
429         const string fragment = "${VERSION_DECL}\n"
430                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
431                                 "precision highp int;\n"
432                                 "layout(location = 0) out uint result;\n"
433                                 "void main (void)\n"
434                                 "{\n" +
435                                 bdyStr +
436                                 "  result = tempResult;\n"
437                                 "}\n";
438 
439         subgroups::addNoSubgroupShader(programCollection);
440 
441         programCollection.add("vert") << glu::VertexSource(vertex);
442         programCollection.add("tesc") << glu::TessellationControlSource(tesc);
443         programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
444         subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
445         programCollection.add("fragment") << glu::FragmentSource(fragment);
446     }
447 }
448 
supportedCheck(Context & context,CaseDefinition caseDef)449 void supportedCheck(Context &context, CaseDefinition caseDef)
450 {
451     DE_UNREF(caseDef);
452     if (!subgroups::isSubgroupSupported(context))
453         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
454 
455     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_BALLOT_BIT))
456     {
457         TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
458     }
459 }
460 
noSSBOtest(Context & context,const CaseDefinition caseDef)461 tcu::TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
462 {
463     if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
464     {
465         if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
466         {
467             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
468                                          " is required to support subgroup operations!");
469         }
470         else
471         {
472             TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
473         }
474     }
475 
476     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
477         return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
478     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
479         return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
480     else if ((SHADER_STAGE_TESS_CONTROL_BIT | SHADER_STAGE_TESS_EVALUATION_BIT) & caseDef.shaderStage)
481         return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, DE_NULL, 0,
482                                                                     checkVertexPipelineStages);
483     else
484         TCU_THROW(InternalError, "Unhandled shader stage");
485 }
486 
test(Context & context,const CaseDefinition caseDef)487 tcu::TestStatus test(Context &context, const CaseDefinition caseDef)
488 {
489     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
490     {
491         if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
492         {
493             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
494                                          " is required to support subgroup operations!");
495         }
496         return subgroups::makeComputeTest(context, FORMAT_R32_UINT, DE_NULL, 0, checkComputeStage);
497     }
498     else
499     {
500         int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
501 
502         ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
503 
504         if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
505         {
506             if ((stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
507                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
508             else
509                 stages = SHADER_STAGE_FRAGMENT_BIT;
510         }
511 
512         if ((ShaderStageFlags)0u == stages)
513             TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
514 
515         return subgroups::allStages(context, FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, stages);
516     }
517     return tcu::TestStatus::pass("OK");
518 }
519 } // namespace
520 
createSubgroupsBallotOtherTests(deqp::Context & testCtx)521 deqp::TestCaseGroup *createSubgroupsBallotOtherTests(deqp::Context &testCtx)
522 {
523     de::MovePtr<deqp::TestCaseGroup> graphicGroup(
524         new deqp::TestCaseGroup(testCtx, "graphics", "Subgroup ballot other category tests: graphics"));
525     de::MovePtr<deqp::TestCaseGroup> computeGroup(
526         new deqp::TestCaseGroup(testCtx, "compute", "Subgroup ballot other category tests: compute"));
527     de::MovePtr<deqp::TestCaseGroup> framebufferGroup(
528         new deqp::TestCaseGroup(testCtx, "framebuffer", "Subgroup ballot other category tests: framebuffer"));
529 
530     const ShaderStageFlags stages[] = {
531         SHADER_STAGE_VERTEX_BIT,
532         SHADER_STAGE_TESS_EVALUATION_BIT,
533         SHADER_STAGE_TESS_CONTROL_BIT,
534         SHADER_STAGE_GEOMETRY_BIT,
535     };
536 
537     for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
538     {
539         const string op = de::toLower(getOpTypeName(opTypeIndex));
540         {
541             const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT};
542             SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), op, "", supportedCheck,
543                                                                          initPrograms, test, caseDef);
544         }
545 
546         {
547             const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS};
548             SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(), op, "", supportedCheck,
549                                                                          initPrograms, test, caseDef);
550         }
551 
552         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
553         {
554             const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex]};
555             SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
556                 framebufferGroup.get(), op + "_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck,
557                 initFrameBufferPrograms, noSSBOtest, caseDef);
558         }
559     }
560 
561     de::MovePtr<deqp::TestCaseGroup> group(
562         new deqp::TestCaseGroup(testCtx, "ballot_other", "Subgroup ballot other category tests"));
563 
564     group->addChild(graphicGroup.release());
565     group->addChild(computeGroup.release());
566     group->addChild(framebufferGroup.release());
567 
568     return group.release();
569 }
570 
571 } // namespace subgroups
572 } // namespace glc
573