xref: /aosp_15_r20/external/deqp/external/openglcts/modules/common/subgroups/glcSubgroupsClusteredTests.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "glcSubgroupsClusteredTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
// Clustered subgroup operations exercised by this file. Each value selects a
// GLSL built-in in getOpTypeName(). OPTYPE_CLUSTERED_LAST is one past the last
// real operation and is rejected by getOpTypeName().
enum OpType
{
    OPTYPE_CLUSTERED_ADD  = 0, // subgroupClusteredAdd
    OPTYPE_CLUSTERED_MUL  = 1, // subgroupClusteredMul
    OPTYPE_CLUSTERED_MIN  = 2, // subgroupClusteredMin
    OPTYPE_CLUSTERED_MAX  = 3, // subgroupClusteredMax
    OPTYPE_CLUSTERED_AND  = 4, // subgroupClusteredAnd
    OPTYPE_CLUSTERED_OR   = 5, // subgroupClusteredOr
    OPTYPE_CLUSTERED_XOR  = 6, // subgroupClusteredXor
    OPTYPE_CLUSTERED_LAST = 7  // end marker, not a real operation
};
52 
checkVertexPipelineStages(std::vector<const void * > datas,uint32_t width,uint32_t)53 static bool checkVertexPipelineStages(std::vector<const void *> datas, uint32_t width, uint32_t)
54 {
55     return glc::subgroups::check(datas, width, 1);
56 }
57 
checkComputeStage(std::vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)58 static bool checkComputeStage(std::vector<const void *> datas, const uint32_t numWorkgroups[3],
59                               const uint32_t localSize[3], uint32_t)
60 {
61     return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
62 }
63 
getOpTypeName(int opType)64 std::string getOpTypeName(int opType)
65 {
66     switch (opType)
67     {
68     default:
69         DE_FATAL("Unsupported op type");
70         return "";
71     case OPTYPE_CLUSTERED_ADD:
72         return "subgroupClusteredAdd";
73     case OPTYPE_CLUSTERED_MUL:
74         return "subgroupClusteredMul";
75     case OPTYPE_CLUSTERED_MIN:
76         return "subgroupClusteredMin";
77     case OPTYPE_CLUSTERED_MAX:
78         return "subgroupClusteredMax";
79     case OPTYPE_CLUSTERED_AND:
80         return "subgroupClusteredAnd";
81     case OPTYPE_CLUSTERED_OR:
82         return "subgroupClusteredOr";
83     case OPTYPE_CLUSTERED_XOR:
84         return "subgroupClusteredXor";
85     }
86 }
87 
getOpTypeOperation(int opType,Format format,std::string lhs,std::string rhs)88 std::string getOpTypeOperation(int opType, Format format, std::string lhs, std::string rhs)
89 {
90     switch (opType)
91     {
92     default:
93         DE_FATAL("Unsupported op type");
94         return "";
95     case OPTYPE_CLUSTERED_ADD:
96         return lhs + " + " + rhs;
97     case OPTYPE_CLUSTERED_MUL:
98         return lhs + " * " + rhs;
99     case OPTYPE_CLUSTERED_MIN:
100         switch (format)
101         {
102         default:
103             return "min(" + lhs + ", " + rhs + ")";
104         case FORMAT_R32_SFLOAT:
105         case FORMAT_R64_SFLOAT:
106             return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs +
107                    ")))";
108         case FORMAT_R32G32_SFLOAT:
109         case FORMAT_R32G32B32_SFLOAT:
110         case FORMAT_R32G32B32A32_SFLOAT:
111         case FORMAT_R64G64_SFLOAT:
112         case FORMAT_R64G64B64_SFLOAT:
113         case FORMAT_R64G64B64A64_SFLOAT:
114             return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" +
115                    lhs + "))";
116         }
117     case OPTYPE_CLUSTERED_MAX:
118         switch (format)
119         {
120         default:
121             return "max(" + lhs + ", " + rhs + ")";
122         case FORMAT_R32_SFLOAT:
123         case FORMAT_R64_SFLOAT:
124             return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs +
125                    ")))";
126         case FORMAT_R32G32_SFLOAT:
127         case FORMAT_R32G32B32_SFLOAT:
128         case FORMAT_R32G32B32A32_SFLOAT:
129         case FORMAT_R64G64_SFLOAT:
130         case FORMAT_R64G64B64_SFLOAT:
131         case FORMAT_R64G64B64A64_SFLOAT:
132             return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" +
133                    lhs + "))";
134         }
135     case OPTYPE_CLUSTERED_AND:
136         switch (format)
137         {
138         default:
139             return lhs + " & " + rhs;
140         case FORMAT_R32_BOOL:
141             return lhs + " && " + rhs;
142         case FORMAT_R32G32_BOOL:
143             return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
144         case FORMAT_R32G32B32_BOOL:
145             return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs +
146                    ".z)";
147         case FORMAT_R32G32B32A32_BOOL:
148             return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs +
149                    ".z, " + lhs + ".w && " + rhs + ".w)";
150         }
151     case OPTYPE_CLUSTERED_OR:
152         switch (format)
153         {
154         default:
155             return lhs + " | " + rhs;
156         case FORMAT_R32_BOOL:
157             return lhs + " || " + rhs;
158         case FORMAT_R32G32_BOOL:
159             return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
160         case FORMAT_R32G32B32_BOOL:
161             return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs +
162                    ".z)";
163         case FORMAT_R32G32B32A32_BOOL:
164             return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs +
165                    ".z, " + lhs + ".w || " + rhs + ".w)";
166         }
167     case OPTYPE_CLUSTERED_XOR:
168         switch (format)
169         {
170         default:
171             return lhs + " ^ " + rhs;
172         case FORMAT_R32_BOOL:
173             return lhs + " ^^ " + rhs;
174         case FORMAT_R32G32_BOOL:
175             return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
176         case FORMAT_R32G32B32_BOOL:
177             return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs +
178                    ".z)";
179         case FORMAT_R32G32B32A32_BOOL:
180             return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs +
181                    ".z, " + lhs + ".w ^^ " + rhs + ".w)";
182         }
183     }
184 }
185 
getIdentity(int opType,Format format)186 std::string getIdentity(int opType, Format format)
187 {
188     bool isFloat    = false;
189     bool isInt      = false;
190     bool isUnsigned = false;
191 
192     switch (format)
193     {
194     default:
195         DE_FATAL("Unhandled format!");
196         break;
197     case FORMAT_R32_SINT:
198     case FORMAT_R32G32_SINT:
199     case FORMAT_R32G32B32_SINT:
200     case FORMAT_R32G32B32A32_SINT:
201         isInt = true;
202         break;
203     case FORMAT_R32_UINT:
204     case FORMAT_R32G32_UINT:
205     case FORMAT_R32G32B32_UINT:
206     case FORMAT_R32G32B32A32_UINT:
207         isUnsigned = true;
208         break;
209     case FORMAT_R32_SFLOAT:
210     case FORMAT_R32G32_SFLOAT:
211     case FORMAT_R32G32B32_SFLOAT:
212     case FORMAT_R32G32B32A32_SFLOAT:
213     case FORMAT_R64_SFLOAT:
214     case FORMAT_R64G64_SFLOAT:
215     case FORMAT_R64G64B64_SFLOAT:
216     case FORMAT_R64G64B64A64_SFLOAT:
217         isFloat = true;
218         break;
219     case FORMAT_R32_BOOL:
220     case FORMAT_R32G32_BOOL:
221     case FORMAT_R32G32B32_BOOL:
222     case FORMAT_R32G32B32A32_BOOL:
223         break; // bool types are not anything
224     }
225 
226     switch (opType)
227     {
228     default:
229         DE_FATAL("Unsupported op type");
230         return "";
231     case OPTYPE_CLUSTERED_ADD:
232         return subgroups::getFormatNameForGLSL(format) + "(0)";
233     case OPTYPE_CLUSTERED_MUL:
234         return subgroups::getFormatNameForGLSL(format) + "(1)";
235     case OPTYPE_CLUSTERED_MIN:
236         if (isFloat)
237         {
238             return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
239         }
240         else if (isInt)
241         {
242             return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
243         }
244         else if (isUnsigned)
245         {
246             return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
247         }
248         else
249         {
250             DE_FATAL("Unhandled case");
251             return "";
252         }
253     case OPTYPE_CLUSTERED_MAX:
254         if (isFloat)
255         {
256             return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
257         }
258         else if (isInt)
259         {
260             return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
261         }
262         else if (isUnsigned)
263         {
264             return subgroups::getFormatNameForGLSL(format) + "(0u)";
265         }
266         else
267         {
268             DE_FATAL("Unhandled case");
269             return "";
270         }
271     case OPTYPE_CLUSTERED_AND:
272         return subgroups::getFormatNameForGLSL(format) + "(~0)";
273     case OPTYPE_CLUSTERED_OR:
274         return subgroups::getFormatNameForGLSL(format) + "(0)";
275     case OPTYPE_CLUSTERED_XOR:
276         return subgroups::getFormatNameForGLSL(format) + "(0)";
277     }
278 }
279 
getCompare(int opType,Format format,std::string lhs,std::string rhs)280 std::string getCompare(int opType, Format format, std::string lhs, std::string rhs)
281 {
282     std::string formatName = subgroups::getFormatNameForGLSL(format);
283     switch (format)
284     {
285     default:
286         return "all(equal(" + lhs + ", " + rhs + "))";
287     case FORMAT_R32_BOOL:
288     case FORMAT_R32_UINT:
289     case FORMAT_R32_SINT:
290         return "(" + lhs + " == " + rhs + ")";
291     case FORMAT_R32_SFLOAT:
292     case FORMAT_R64_SFLOAT:
293         switch (opType)
294         {
295         default:
296             return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
297         case OPTYPE_CLUSTERED_MIN:
298         case OPTYPE_CLUSTERED_MAX:
299             return "(" + lhs + " == " + rhs + ")";
300         }
301     case FORMAT_R32G32_SFLOAT:
302     case FORMAT_R32G32B32_SFLOAT:
303     case FORMAT_R32G32B32A32_SFLOAT:
304     case FORMAT_R64G64_SFLOAT:
305     case FORMAT_R64G64B64_SFLOAT:
306     case FORMAT_R64G64B64A64_SFLOAT:
307         switch (opType)
308         {
309         default:
310             return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
311         case OPTYPE_CLUSTERED_MIN:
312         case OPTYPE_CLUSTERED_MAX:
313             return "all(equal(" + lhs + ", " + rhs + "))";
314         }
315     }
316 }
317 
318 struct CaseDefinition
319 {
320     int opType;
321     ShaderStageFlags shaderStage;
322     Format format;
323 };
324 
getBodySource(CaseDefinition caseDef)325 std::string getBodySource(CaseDefinition caseDef)
326 {
327     std::ostringstream bdy;
328     bdy << "  bool tempResult = true;\n";
329 
330     for (uint32_t i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
331     {
332         bdy << "  {\n"
333             << "    const uint clusterSize = " << i << "u;\n"
334             << "    if (clusterSize <= gl_SubgroupSize)\n"
335             << "    {\n"
336             << "      " << subgroups::getFormatNameForGLSL(caseDef.format)
337             << " op = " << getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
338             << "      for (uint clusterOffset = 0u; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
339             << "      {\n"
340             << "        " << subgroups::getFormatNameForGLSL(caseDef.format)
341             << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
342             << "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
343             << "        {\n"
344             << "          if (subgroupBallotBitExtract(mask, index))\n"
345             << "          {\n"
346             << "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
347             << "          }\n"
348             << "        }\n"
349             << "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + "
350                "clusterSize)))\n"
351             << "        {\n"
352             << "          if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
353             << "          {\n"
354             << "            tempResult = false;\n"
355             << "          }\n"
356             << "        }\n"
357             << "      }\n"
358             << "    }\n"
359             << "  }\n";
360     }
361     return bdy.str();
362 }
363 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)364 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
365 {
366     subgroups::setFragmentShaderFrameBuffer(programCollection);
367 
368     if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
369         subgroups::setVertexShaderFrameBuffer(programCollection);
370 
371     std::string bdy = getBodySource(caseDef);
372 
373     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
374     {
375         std::ostringstream vertexSrc;
376         vertexSrc << "${VERSION_DECL}\n"
377                   << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
378                   << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
379                   << "layout(location = 0) in highp vec4 in_position;\n"
380                   << "layout(location = 0) out float out_color;\n"
381                   << "layout(binding = 0, std140) uniform Buffer0\n"
382                   << "{\n"
383                   << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
384                   << subgroups::maxSupportedSubgroupSize() << "];\n"
385                   << "};\n"
386                   << "\n"
387                   << "void main (void)\n"
388                   << "{\n"
389                   << "  uvec4 mask = subgroupBallot(true);\n"
390                   << bdy << "  out_color = float(tempResult ? 1 : 0);\n"
391                   << "  gl_Position = in_position;\n"
392                   << "  gl_PointSize = 1.0f;\n"
393                   << "}\n";
394         programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
395     }
396     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
397     {
398         std::ostringstream geometry;
399 
400         geometry << "${VERSION_DECL}\n"
401                  << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
402                  << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
403                  << "layout(points) in;\n"
404                  << "layout(points, max_vertices = 1) out;\n"
405                  << "layout(location = 0) out float out_color;\n"
406                  << "layout(binding = 0, std140) uniform Buffer0\n"
407                  << "{\n"
408                  << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
409                  << subgroups::maxSupportedSubgroupSize() << "];\n"
410                  << "};\n"
411                  << "\n"
412                  << "void main (void)\n"
413                  << "{\n"
414                  << "  uvec4 mask = subgroupBallot(true);\n"
415                  << bdy << "  out_color = tempResult ? 1.0 : 0.0;\n"
416                  << "  gl_Position = gl_in[0].gl_Position;\n"
417                  << "  EmitVertex();\n"
418                  << "  EndPrimitive();\n"
419                  << "}\n";
420 
421         programCollection.add("geometry") << glu::GeometrySource(geometry.str());
422     }
423     else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
424     {
425         std::ostringstream controlSource;
426 
427         controlSource << "${VERSION_DECL}\n"
428                       << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
429                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
430                       << "layout(vertices = 2) out;\n"
431                       << "layout(location = 0) out float out_color[];\n"
432                       << "layout(binding = 0, std140) uniform Buffer0\n"
433                       << "{\n"
434                       << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
435                       << subgroups::maxSupportedSubgroupSize() << "];\n"
436                       << "};\n"
437                       << "\n"
438                       << "void main (void)\n"
439                       << "{\n"
440                       << "  if (gl_InvocationID == 0)\n"
441                       << "  {\n"
442                       << "    gl_TessLevelOuter[0] = 1.0f;\n"
443                       << "    gl_TessLevelOuter[1] = 1.0f;\n"
444                       << "  }\n"
445                       << "  uvec4 mask = subgroupBallot(true);\n"
446                       << bdy << "  out_color[gl_InvocationID] = tempResult ? 1.0 : 0.0;\n"
447                       << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
448                       << "}\n";
449 
450         programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
451         subgroups::setTesEvalShaderFrameBuffer(programCollection);
452     }
453     else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
454     {
455         std::ostringstream evaluationSource;
456 
457         evaluationSource << "${VERSION_DECL}\n"
458                          << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
459                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
460                          << "layout(isolines, equal_spacing, ccw ) in;\n"
461                          << "layout(location = 0) out float out_color;\n"
462                          << "layout(binding = 0, std140) uniform Buffer0\n"
463                          << "{\n"
464                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
465                          << subgroups::maxSupportedSubgroupSize() << "];\n"
466                          << "};\n"
467                          << "\n"
468                          << "void main (void)\n"
469                          << "{\n"
470                          << "  uvec4 mask = subgroupBallot(true);\n"
471                          << bdy << "  out_color = tempResult ? 1.0 : 0.0;\n"
472                          << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
473                          << "}\n";
474 
475         subgroups::setTesCtrlShaderFrameBuffer(programCollection);
476         programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
477     }
478     else
479     {
480         DE_FATAL("Unsupported shader stage");
481     }
482 }
483 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)484 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
485 {
486     std::string bdy = getBodySource(caseDef);
487 
488     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
489     {
490         std::ostringstream src;
491 
492         src << "${VERSION_DECL}\n"
493             << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
494             << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
495             << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
496             << "layout(binding = 0, std430) buffer Buffer0\n"
497             << "{\n"
498             << "  uint result[];\n"
499             << "};\n"
500             << "layout(binding = 1, std430) buffer Buffer1\n"
501             << "{\n"
502             << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
503             << "};\n"
504             << "\n"
505             << "void main (void)\n"
506             << "{\n"
507             << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
508             << "  highp uint offset = globalSize.x * ((globalSize.y * "
509                "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
510                "gl_GlobalInvocationID.x;\n"
511             << "  uvec4 mask = subgroupBallot(true);\n"
512             << bdy << "  result[offset] = tempResult ? 1u : 0u;\n"
513             << "}\n";
514 
515         programCollection.add("comp") << glu::ComputeSource(src.str());
516     }
517     else
518     {
519         {
520             const string vertex =
521                 "${VERSION_DECL}\n"
522                 "#extension GL_KHR_shader_subgroup_clustered: enable\n"
523                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
524                 "layout(binding = 0, std430) buffer Buffer0\n"
525                 "{\n"
526                 "  uint result[];\n"
527                 "} b0;\n"
528                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
529                 "{\n"
530                 "  " +
531                 subgroups::getFormatNameForGLSL(caseDef.format) +
532                 " data[];\n"
533                 "};\n"
534                 "\n"
535                 "void main (void)\n"
536                 "{\n"
537                 "  uvec4 mask = subgroupBallot(true);\n" +
538                 bdy +
539                 "  b0.result[gl_VertexID] = tempResult ? 1u : 0u;\n"
540                 "  float pixelSize = 2.0f/1024.0f;\n"
541                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
542                 "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
543                 "}\n";
544 
545             programCollection.add("vert") << glu::VertexSource(vertex);
546         }
547 
548         {
549             const string tesc = "${VERSION_DECL}\n"
550                                 "#extension GL_KHR_shader_subgroup_clustered: enable\n"
551                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
552                                 "layout(vertices=1) out;\n"
553                                 "layout(binding = 1, std430) buffer Buffer1\n"
554                                 "{\n"
555                                 "  uint result[];\n"
556                                 "} b1;\n"
557                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
558                                 "{\n"
559                                 "  " +
560                                 subgroups::getFormatNameForGLSL(caseDef.format) +
561                                 " data[];\n"
562                                 "};\n"
563                                 "\n"
564                                 "void main (void)\n"
565                                 "{\n"
566                                 "  uvec4 mask = subgroupBallot(true);\n" +
567                                 bdy +
568                                 "  b1.result[gl_PrimitiveID] = tempResult ? 1u : 0u;\n"
569                                 "  if (gl_InvocationID == 0)\n"
570                                 "  {\n"
571                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
572                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
573                                 "  }\n"
574                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
575                                 "}\n";
576 
577             programCollection.add("tesc") << glu::TessellationControlSource(tesc);
578         }
579 
580         {
581             const string tese = "${VERSION_DECL}\n"
582                                 "#extension GL_KHR_shader_subgroup_clustered: enable\n"
583                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
584                                 "layout(isolines) in;\n"
585                                 "layout(binding = 2, std430) buffer Buffer2\n"
586                                 "{\n"
587                                 "  uint result[];\n"
588                                 "} b2;\n"
589                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
590                                 "{\n"
591                                 "  " +
592                                 subgroups::getFormatNameForGLSL(caseDef.format) +
593                                 " data[];\n"
594                                 "};\n"
595                                 "\n"
596                                 "void main (void)\n"
597                                 "{\n"
598                                 "  uvec4 mask = subgroupBallot(true);\n" +
599                                 bdy +
600                                 "  b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = tempResult ? 1u : 0u;\n"
601                                 "  float pixelSize = 2.0f/1024.0f;\n"
602                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
603                                 "}\n";
604             programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
605         }
606 
607         {
608             const string geometry =
609                 // version string added by addGeometryShadersFromTemplate
610                 "#extension GL_KHR_shader_subgroup_clustered: enable\n"
611                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
612                 "layout(${TOPOLOGY}) in;\n"
613                 "layout(points, max_vertices = 1) out;\n"
614                 "layout(binding = 3, std430) buffer Buffer3\n"
615                 "{\n"
616                 "  uint result[];\n"
617                 "} b3;\n"
618                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
619                 "{\n"
620                 "  " +
621                 subgroups::getFormatNameForGLSL(caseDef.format) +
622                 " data[];\n"
623                 "};\n"
624                 "\n"
625                 "void main (void)\n"
626                 "{\n"
627                 "  uvec4 mask = subgroupBallot(true);\n" +
628                 bdy +
629                 "  b3.result[gl_PrimitiveIDIn] = tempResult ? 1u : 0u;\n"
630                 "  gl_Position = gl_in[0].gl_Position;\n"
631                 "  EmitVertex();\n"
632                 "  EndPrimitive();\n"
633                 "}\n";
634             subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
635         }
636 
637         {
638             const string fragment = "${VERSION_DECL}\n"
639                                     "#extension GL_KHR_shader_subgroup_clustered: enable\n"
640                                     "#extension GL_KHR_shader_subgroup_ballot: enable\n"
641                                     "precision highp int;\n"
642                                     "precision highp float;\n"
643                                     "layout(location = 0) out uint result;\n"
644                                     "layout(binding = 4, std430) readonly buffer Buffer4\n"
645                                     "{\n"
646                                     "  " +
647                                     subgroups::getFormatNameForGLSL(caseDef.format) +
648                                     " data[];\n"
649                                     "};\n"
650                                     "void main (void)\n"
651                                     "{\n"
652                                     "  uvec4 mask = subgroupBallot(true);\n" +
653                                     bdy +
654                                     "  result = tempResult ? 1u : 0u;\n"
655                                     "}\n";
656             programCollection.add("fragment") << glu::FragmentSource(fragment);
657         }
658 
659         subgroups::addNoSubgroupShader(programCollection);
660     }
661 }
662 
supportedCheck(Context & context,CaseDefinition caseDef)663 void supportedCheck(Context &context, CaseDefinition caseDef)
664 {
665     if (!subgroups::isSubgroupSupported(context))
666         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
667 
668     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_CLUSTERED_BIT))
669         TCU_THROW(NotSupportedError, "Device does not support subgroup clustered operations");
670 
671     if (subgroups::isDoubleFormat(caseDef.format) && !subgroups::isDoubleSupportedForDevice(context))
672     {
673         TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
674     }
675 }
676 
noSSBOtest(Context & context,const CaseDefinition caseDef)677 tcu::TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
678 {
679     if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
680     {
681         if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
682         {
683             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
684                                          " is required to support subgroup operations!");
685         }
686         else
687         {
688             TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
689         }
690     }
691 
692     subgroups::SSBOData inputData;
693     inputData.format         = caseDef.format;
694     inputData.layout         = subgroups::SSBOData::LayoutStd140;
695     inputData.numElements    = subgroups::maxSupportedSubgroupSize();
696     inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
697     inputData.binding        = 0u;
698 
699     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
700         return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
701     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
702         return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1,
703                                                       checkVertexPipelineStages);
704     else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
705         return subgroups::makeTessellationEvaluationFrameBufferTest(
706             context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
707     else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
708         return subgroups::makeTessellationEvaluationFrameBufferTest(
709             context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
710     else
711         TCU_THROW(InternalError, "Unhandled shader stage");
712 }
713 
test(Context & context,const CaseDefinition caseDef)714 tcu::TestStatus test(Context &context, const CaseDefinition caseDef)
715 {
716     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
717     {
718         if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
719         {
720             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
721                                          " is required to support subgroup operations!");
722         }
723         subgroups::SSBOData inputData;
724         inputData.format         = caseDef.format;
725         inputData.layout         = subgroups::SSBOData::LayoutStd430;
726         inputData.numElements    = subgroups::maxSupportedSubgroupSize();
727         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
728         inputData.binding        = 1u;
729 
730         return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
731     }
732     else
733     {
734         int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
735 
736         ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
737 
738         if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
739         {
740             if ((stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
741                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
742             else
743                 stages = SHADER_STAGE_FRAGMENT_BIT;
744         }
745 
746         if ((ShaderStageFlags)0u == stages)
747             TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
748 
749         subgroups::SSBOData inputData;
750         inputData.format         = caseDef.format;
751         inputData.layout         = subgroups::SSBOData::LayoutStd430;
752         inputData.numElements    = subgroups::maxSupportedSubgroupSize();
753         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
754         inputData.binding        = 4u;
755         inputData.stages         = stages;
756 
757         return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
758     }
759 }
760 } // namespace
761 
createSubgroupsClusteredTests(deqp::Context & testCtx)762 deqp::TestCaseGroup *createSubgroupsClusteredTests(deqp::Context &testCtx)
763 {
764     de::MovePtr<deqp::TestCaseGroup> graphicGroup(
765         new deqp::TestCaseGroup(testCtx, "graphics", "Subgroup clustered category tests: graphics"));
766     de::MovePtr<deqp::TestCaseGroup> computeGroup(
767         new deqp::TestCaseGroup(testCtx, "compute", "Subgroup clustered category tests: compute"));
768     de::MovePtr<deqp::TestCaseGroup> framebufferGroup(
769         new deqp::TestCaseGroup(testCtx, "framebuffer", "Subgroup clustered category tests: framebuffer"));
770 
771     const ShaderStageFlags stages[] = {SHADER_STAGE_VERTEX_BIT, SHADER_STAGE_TESS_EVALUATION_BIT,
772                                        SHADER_STAGE_TESS_CONTROL_BIT, SHADER_STAGE_GEOMETRY_BIT};
773 
774     const Format formats[] = {
775         FORMAT_R32_SINT,   FORMAT_R32G32_SINT,   FORMAT_R32G32B32_SINT,   FORMAT_R32G32B32A32_SINT,
776         FORMAT_R32_UINT,   FORMAT_R32G32_UINT,   FORMAT_R32G32B32_UINT,   FORMAT_R32G32B32A32_UINT,
777         FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT, FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
778         FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT, FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
779         FORMAT_R32_BOOL,   FORMAT_R32G32_BOOL,   FORMAT_R32G32B32_BOOL,   FORMAT_R32G32B32A32_BOOL,
780     };
781 
782     for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
783     {
784         const Format format = formats[formatIndex];
785 
786         for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex)
787         {
788             bool isBool  = false;
789             bool isFloat = false;
790 
791             switch (format)
792             {
793             default:
794                 break;
795             case FORMAT_R32_SFLOAT:
796             case FORMAT_R32G32_SFLOAT:
797             case FORMAT_R32G32B32_SFLOAT:
798             case FORMAT_R32G32B32A32_SFLOAT:
799             case FORMAT_R64_SFLOAT:
800             case FORMAT_R64G64_SFLOAT:
801             case FORMAT_R64G64B64_SFLOAT:
802             case FORMAT_R64G64B64A64_SFLOAT:
803                 isFloat = true;
804                 break;
805             case FORMAT_R32_BOOL:
806             case FORMAT_R32G32_BOOL:
807             case FORMAT_R32G32B32_BOOL:
808             case FORMAT_R32G32B32A32_BOOL:
809                 isBool = true;
810                 break;
811             }
812 
813             bool isBitwiseOp = false;
814 
815             switch (opTypeIndex)
816             {
817             default:
818                 break;
819             case OPTYPE_CLUSTERED_AND:
820             case OPTYPE_CLUSTERED_OR:
821             case OPTYPE_CLUSTERED_XOR:
822                 isBitwiseOp = true;
823                 break;
824             }
825 
826             if (isFloat && isBitwiseOp)
827             {
828                 // Skip float with bitwise category.
829                 continue;
830             }
831 
832             if (isBool && !isBitwiseOp)
833             {
834                 // Skip bool when its not the bitwise category.
835                 continue;
836             }
837 
838             const std::string name =
839                 de::toLower(getOpTypeName(opTypeIndex)) + "_" + subgroups::getFormatNameForGLSL(format);
840 
841             {
842                 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format};
843                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
844                     computeGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
845             }
846 
847             {
848                 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format};
849                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
850                     graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
851             }
852 
853             for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
854             {
855                 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
856                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
857                     framebufferGroup.get(), name + "_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck,
858                     initFrameBufferPrograms, noSSBOtest, caseDef);
859             }
860         }
861     }
862     de::MovePtr<deqp::TestCaseGroup> group(
863         new deqp::TestCaseGroup(testCtx, "clustered", "Subgroup clustered category tests"));
864 
865     group->addChild(graphicGroup.release());
866     group->addChild(computeGroup.release());
867     group->addChild(framebufferGroup.release());
868 
869     return group.release();
870 }
871 
872 } // namespace subgroups
873 } // namespace glc
874