xref: /aosp_15_r20/external/deqp/external/vulkancts/modules/vulkan/subgroups/vktSubgroupsBallotTests.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "vktSubgroupsBallotTests.hpp"
27 #include "vktSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 using namespace vk;
35 using namespace vkt;
36 
37 namespace
38 {
// Parameters describing one subgroup-ballot test case.
39 struct CaseDefinition
40 {
41     VkShaderStageFlags shaderStage; // Stage under test; compared with single stage bits (vertex/geometry/tess control...) to pick the shader to build.
42     de::SharedPtr<bool> geometryPointSizeSupported; // Dereferenced when building the geometry shader: adds the GeometryPointSize capability and the gl_PointSize copy. Held by pointer, presumably so it can be filled in after case creation -- confirm.
43     bool extShaderSubGroupBallotTests; // true: emit SPV_KHR_shader_ballot / OpSubgroupBallotKHR; false: emit OpGroupNonUniformBallot (GroupNonUniformBallot capability).
44     bool requiredSubgroupSize; // Not read in the code visible here; presumably selects a required-subgroup-size variant -- TODO confirm.
45 };
46 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,uint32_t width,uint32_t)47 static bool checkVertexPipelineStages(const void *internalData, vector<const void *> datas, uint32_t width, uint32_t)
48 {
49     DE_UNREF(internalData);
50 
51     return subgroups::check(datas, width, 0x7);
52 }
53 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)54 static bool checkComputeOrMesh(const void *internalData, vector<const void *> datas, const uint32_t numWorkgroups[3],
55                                const uint32_t localSize[3], uint32_t)
56 {
57     DE_UNREF(internalData);
58 
59     return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 0x7);
60 }
61 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)62 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
63 {
64     const SpirVAsmBuildOptions buildOptionsSpr(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3);
65     const string extensionHeader =
66         (caseDef.extShaderSubGroupBallotTests ? "OpExtension \"SPV_KHR_shader_ballot\"\n" : "");
67     const string capabilityBallotHeader =
68         (caseDef.extShaderSubGroupBallotTests ? "OpCapability SubgroupBallotKHR\n" :
69                                                 "OpCapability GroupNonUniformBallot\n");
70     const string subgroupSizeStr = de::toString(subgroups::maxSupportedSubgroupSize());
71 
72     subgroups::setFragmentShaderFrameBuffer(programCollection);
73 
74     if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
75         subgroups::setVertexShaderFrameBuffer(programCollection);
76 
77     if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
78     {
79         /*
80             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
81             "layout(location = 0) in highp vec4 in_position;\n"
82             "layout(location = 0) out float out_color;\n"
83             "layout(set = 0, binding = 0) uniform Buffer1\n"
84             "{\n"
85             "  uint data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
86             "};\n"
87             "\n"
88             "void main (void)\n"
89             "{\n"
90             "  uint tempResult = 0;\n"
91             "  tempResult |= !bool(uvec4(0) == subgroupBallot(true)) ? 0x1 : 0;\n"
92             "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
93             "  tempResult |= !bool(uvec4(0) == subgroupBallot(bData)) ? 0x2 : 0;\n"
94             "  tempResult |= uvec4(0) == subgroupBallot(false) ? 0x4 : 0;\n"
95             "  out_color = float(tempResult);\n"
96             "  gl_Position = in_position;\n"
97             "  gl_PointSize = 1.0f;\n"
98             "}\n";
99         */
100         const string vertex = "; SPIR-V\n"
101                               "; Version: 1.3\n"
102                               "; Generator: Khronos Glslang Reference Front End; 2\n"
103                               "; Bound: 76\n"
104                               "; Schema: 0\n"
105                               "OpCapability Shader\n"
106                               "OpCapability GroupNonUniform\n" +
107                               capabilityBallotHeader + extensionHeader +
108                               "%1 = OpExtInstImport \"GLSL.std.450\"\n"
109                               "OpMemoryModel Logical GLSL450\n"
110                               "OpEntryPoint Vertex %4 \"main\" %35 %62 %70 %72\n"
111                               "OpDecorate %30 ArrayStride 16\n"
112                               "OpMemberDecorate %31 0 Offset 0\n"
113                               "OpDecorate %31 Block\n"
114                               "OpDecorate %33 DescriptorSet 0\n"
115                               "OpDecorate %33 Binding 0\n"
116                               "OpDecorate %35 RelaxedPrecision\n"
117                               "OpDecorate %35 BuiltIn SubgroupLocalInvocationId\n"
118                               "OpDecorate %36 RelaxedPrecision\n"
119                               "OpDecorate %62 Location 0\n"
120                               "OpMemberDecorate %68 0 BuiltIn Position\n"
121                               "OpMemberDecorate %68 1 BuiltIn PointSize\n"
122                               "OpMemberDecorate %68 2 BuiltIn ClipDistance\n"
123                               "OpMemberDecorate %68 3 BuiltIn CullDistance\n"
124                               "OpDecorate %68 Block\n"
125                               "OpDecorate %72 Location 0\n"
126                               "%2 = OpTypeVoid\n"
127                               "%3 = OpTypeFunction %2\n"
128                               "%6 = OpTypeInt 32 0\n"
129                               "%7 = OpTypePointer Function %6\n"
130                               "%9 = OpConstant %6 0\n"
131                               "%10 = OpTypeVector %6 4\n"
132                               "%11 = OpConstantComposite %10 %9 %9 %9 %9\n"
133                               "%12 = OpTypeBool\n"
134                               "%13 = OpConstantTrue %12\n"
135                               "%14 = OpConstant %6 3\n"
136                               "%16 = OpTypeVector %12 4\n"
137                               "%20 = OpTypeInt 32 1\n"
138                               "%21 = OpConstant %20 1\n"
139                               "%22 = OpConstant %20 0\n"
140                               "%27 = OpTypePointer Function %12\n"
141                               "%29 = OpConstant %6 " +
142                               subgroupSizeStr +
143                               "\n"
144                               "%30 = OpTypeArray %6 %29\n"
145                               "%31 = OpTypeStruct %30\n"
146                               "%32 = OpTypePointer Uniform %31\n"
147                               "%33 = OpVariable %32 Uniform\n"
148                               "%34 = OpTypePointer Input %6\n"
149                               "%35 = OpVariable %34 Input\n"
150                               "%37 = OpTypePointer Uniform %6\n"
151                               "%46 = OpConstant %20 2\n"
152                               "%51 = OpConstantFalse %12\n"
153                               "%55 = OpConstant %20 4\n"
154                               "%60 = OpTypeFloat 32\n"
155                               "%61 = OpTypePointer Output %60\n"
156                               "%62 = OpVariable %61 Output\n"
157                               "%65 = OpTypeVector %60 4\n"
158                               "%66 = OpConstant %6 1\n"
159                               "%67 = OpTypeArray %60 %66\n"
160                               "%68 = OpTypeStruct %65 %60 %67 %67\n"
161                               "%69 = OpTypePointer Output %68\n"
162                               "%70 = OpVariable %69 Output\n"
163                               "%71 = OpTypePointer Input %65\n"
164                               "%72 = OpVariable %71 Input\n"
165                               "%74 = OpTypePointer Output %65\n"
166                               "%76 = OpConstant %60 1\n"
167                               "%4 = OpFunction %2 None %3\n"
168                               "%5 = OpLabel\n"
169                               "%8 = OpVariable %7 Function\n"
170                               "%28 = OpVariable %27 Function\n"
171                               "OpStore %8 %9\n"
172                               "%15 = " +
173                               (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %13" :
174                                                                       "OpGroupNonUniformBallot %10 %14 %13") +
175                               "\n"
176                               "%17 = OpIEqual %16 %11 %15\n"
177                               "%18 = OpAll %12 %17\n"
178                               "%19 = OpLogicalNot %12 %18\n"
179                               "%23 = OpSelect %20 %19 %21 %22\n"
180                               "%24 = OpBitcast %6 %23\n"
181                               "%25 = OpLoad %6 %8\n"
182                               "%26 = OpBitwiseOr %6 %25 %24\n"
183                               "OpStore %8 %26\n"
184                               "%36 = OpLoad %6 %35\n"
185                               "%38 = OpAccessChain %37 %33 %22 %36\n"
186                               "%39 = OpLoad %6 %38\n"
187                               "%40 = OpINotEqual %12 %39 %9\n"
188                               "OpStore %28 %40\n"
189                               "%41 = OpLoad %12 %28\n"
190                               "%42 = " +
191                               (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %41" :
192                                                                       "OpGroupNonUniformBallot %10 %14 %41") +
193                               "\n"
194                               "%43 = OpIEqual %16 %11 %42\n"
195                               "%44 = OpAll %12 %43\n"
196                               "%45 = OpLogicalNot %12 %44\n"
197                               "%47 = OpSelect %20 %45 %46 %22\n"
198                               "%48 = OpBitcast %6 %47\n"
199                               "%49 = OpLoad %6 %8\n"
200                               "%50 = OpBitwiseOr %6 %49 %48\n"
201                               "OpStore %8 %50\n"
202                               "%52 = " +
203                               (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %51" :
204                                                                       "OpGroupNonUniformBallot %10 %14 %51") +
205                               "\n"
206                               "%53 = OpIEqual %16 %11 %52\n"
207                               "%54 = OpAll %12 %53\n"
208                               "%56 = OpSelect %20 %54 %55 %22\n"
209                               "%57 = OpBitcast %6 %56\n"
210                               "%58 = OpLoad %6 %8\n"
211                               "%59 = OpBitwiseOr %6 %58 %57\n"
212                               "OpStore %8 %59\n"
213                               "%63 = OpLoad %6 %8\n"
214                               "%64 = OpConvertUToF %60 %63\n"
215                               "OpStore %62 %64\n"
216                               "%73 = OpLoad %65 %72\n"
217                               "%75 = OpAccessChain %74 %70 %22\n"
218                               "OpStore %75 %73\n"
219                               "%77 = OpAccessChain %61 %70 %21\n"
220                               "OpStore %77 %76\n"
221                               "OpReturn\n"
222                               "OpFunctionEnd\n";
223         programCollection.spirvAsmSources.add("vert") << vertex << buildOptionsSpr;
224     }
225     else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
226     {
227         /*
228             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
229             "layout(points) in;\n"
230             "layout(points, max_vertices = 1) out;\n"
231             "layout(location = 0) out float out_color;\n"
232             "layout(set = 0, binding = 0) uniform Buffer1\n"
233             "{\n"
234             "  uint data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
235             "};\n"
236             "\n"
237             "void main (void)\n"
238             "{\n"
239             "  uint tempResult = 0;\n"
240             "  tempResult |= !bool(uvec4(0) == subgroupBallot(true)) ? 0x1 : 0;\n"
241             "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
242             "  tempResult |= !bool(uvec4(0) == subgroupBallot(bData)) ? 0x2 : 0;\n"
243             "  tempResult |= uvec4(0) == subgroupBallot(false) ? 0x4 : 0;\n"
244             "  out_color = float(tempResult);\n"
245             "  gl_Position = gl_in[0].gl_Position;\n"
246             "  gl_PointSize = gl_in[0].gl_PointSize;\n"
247             "  EmitVertex();\n"
248             "  EndPrimitive();\n"
249             "}\n";
250         */
251         ostringstream geometry;
252 
253         geometry << "; SPIR-V\n"
254                  << "; Version: 1.3\n"
255                  << "; Generator: Khronos Glslang Reference Front End; 2\n"
256                  << "; Bound: 80\n"
257                  << "; Schema: 0\n"
258                  << "OpCapability Geometry\n"
259                  << (*caseDef.geometryPointSizeSupported ? "OpCapability GeometryPointSize\n" : "")
260                  << "OpCapability GroupNonUniform\n"
261                  << capabilityBallotHeader.c_str() << extensionHeader.c_str()
262                  << "%1 = OpExtInstImport \"GLSL.std.450\"\n"
263                  << "OpMemoryModel Logical GLSL450\n"
264                  << "OpEntryPoint Geometry %4 \"main\" %35 %62 %70 %74\n"
265                  << "OpExecutionMode %4 InputPoints\n"
266                  << "OpExecutionMode %4 Invocations 1\n"
267                  << "OpExecutionMode %4 OutputPoints\n"
268                  << "OpExecutionMode %4 OutputVertices 1\n"
269                  << "OpDecorate %30 ArrayStride 16\n"
270                  << "OpMemberDecorate %31 0 Offset 0\n"
271                  << "OpDecorate %31 Block\n"
272                  << "OpDecorate %33 DescriptorSet 0\n"
273                  << "OpDecorate %33 Binding 0\n"
274                  << "OpDecorate %35 RelaxedPrecision\n"
275                  << "OpDecorate %35 BuiltIn SubgroupLocalInvocationId\n"
276                  << "OpDecorate %36 RelaxedPrecision\n"
277                  << "OpDecorate %62 Location 0\n"
278                  << "OpMemberDecorate %68 0 BuiltIn Position\n"
279                  << "OpMemberDecorate %68 1 BuiltIn PointSize\n"
280                  << "OpMemberDecorate %68 2 BuiltIn ClipDistance\n"
281                  << "OpMemberDecorate %68 3 BuiltIn CullDistance\n"
282                  << "OpDecorate %68 Block\n"
283                  << "OpMemberDecorate %71 0 BuiltIn Position\n"
284                  << "OpMemberDecorate %71 1 BuiltIn PointSize\n"
285                  << "OpMemberDecorate %71 2 BuiltIn ClipDistance\n"
286                  << "OpMemberDecorate %71 3 BuiltIn CullDistance\n"
287                  << "OpDecorate %71 Block\n"
288                  << "%2 = OpTypeVoid\n"
289                  << "%3 = OpTypeFunction %2\n"
290                  << "%6 = OpTypeInt 32 0\n"
291                  << "%7 = OpTypePointer Function %6\n"
292                  << "%9 = OpConstant %6 0\n"
293                  << "%10 = OpTypeVector %6 4\n"
294                  << "%11 = OpConstantComposite %10 %9 %9 %9 %9\n"
295                  << "%12 = OpTypeBool\n"
296                  << "%13 = OpConstantTrue %12\n"
297                  << "%14 = OpConstant %6 3\n"
298                  << "%16 = OpTypeVector %12 4\n"
299                  << "%20 = OpTypeInt 32 1\n"
300                  << "%21 = OpConstant %20 1\n"
301                  << "%22 = OpConstant %20 0\n"
302                  << "%27 = OpTypePointer Function %12\n"
303                  << "%29 = OpConstant %6 " << subgroupSizeStr << "\n"
304                  << "%30 = OpTypeArray %6 %29\n"
305                  << "%31 = OpTypeStruct %30\n"
306                  << "%32 = OpTypePointer Uniform %31\n"
307                  << "%33 = OpVariable %32 Uniform\n"
308                  << "%34 = OpTypePointer Input %6\n"
309                  << "%35 = OpVariable %34 Input\n"
310                  << "%37 = OpTypePointer Uniform %6\n"
311                  << "%46 = OpConstant %20 2\n"
312                  << "%51 = OpConstantFalse %12\n"
313                  << "%55 = OpConstant %20 4\n"
314                  << "%60 = OpTypeFloat 32\n"
315                  << "%61 = OpTypePointer Output %60\n"
316                  << "%62 = OpVariable %61 Output\n"
317                  << "%65 = OpTypeVector %60 4\n"
318                  << "%66 = OpConstant %6 1\n"
319                  << "%67 = OpTypeArray %60 %66\n"
320                  << "%68 = OpTypeStruct %65 %60 %67 %67\n"
321                  << "%69 = OpTypePointer Output %68\n"
322                  << "%70 = OpVariable %69 Output\n"
323                  << "%71 = OpTypeStruct %65 %60 %67 %67\n"
324                  << "%72 = OpTypeArray %71 %66\n"
325                  << "%73 = OpTypePointer Input %72\n"
326                  << "%74 = OpVariable %73 Input\n"
327                  << "%75 = OpTypePointer Input %65\n"
328                  << "%78 = OpTypePointer Output %65\n"
329                  << (*caseDef.geometryPointSizeSupported ? "%80 = OpTypePointer Input %60\n"
330                                                            "%81 = OpTypePointer Output %60\n" :
331                                                            "")
332                  << "%4 = OpFunction %2 None %3\n"
333                  << "%5 = OpLabel\n"
334                  << "%8 = OpVariable %7 Function\n"
335                  << "%28 = OpVariable %27 Function\n"
336                  << "OpStore %8 %9\n"
337                  << "%15 = "
338                  << (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %13" :
339                                                             "OpGroupNonUniformBallot %10 %14 %13")
340                  << "\n"
341                  << "%17 = OpIEqual %16 %11 %15\n"
342                  << "%18 = OpAll %12 %17\n"
343                  << "%19 = OpLogicalNot %12 %18\n"
344                  << "%23 = OpSelect %20 %19 %21 %22\n"
345                  << "%24 = OpBitcast %6 %23\n"
346                  << "%25 = OpLoad %6 %8\n"
347                  << "%26 = OpBitwiseOr %6 %25 %24\n"
348                  << "OpStore %8 %26\n"
349                  << "%36 = OpLoad %6 %35\n"
350                  << "%38 = OpAccessChain %37 %33 %22 %36\n"
351                  << "%39 = OpLoad %6 %38\n"
352                  << "%40 = OpINotEqual %12 %39 %9\n"
353                  << "OpStore %28 %40\n"
354                  << "%41 = OpLoad %12 %28\n"
355                  << "%42 = "
356                  << (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %41" :
357                                                             "OpGroupNonUniformBallot %10 %14 %41")
358                  << "\n"
359                  << "%43 = OpIEqual %16 %11 %42\n"
360                  << "%44 = OpAll %12 %43\n"
361                  << "%45 = OpLogicalNot %12 %44\n"
362                  << "%47 = OpSelect %20 %45 %46 %22\n"
363                  << "%48 = OpBitcast %6 %47\n"
364                  << "%49 = OpLoad %6 %8\n"
365                  << "%50 = OpBitwiseOr %6 %49 %48\n"
366                  << "OpStore %8 %50\n"
367                  << "%52 = "
368                  << (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %51" :
369                                                             "OpGroupNonUniformBallot %10 %14 %51")
370                  << "\n"
371                  << "%53 = OpIEqual %16 %11 %52\n"
372                  << "%54 = OpAll %12 %53\n"
373                  << "%56 = OpSelect %20 %54 %55 %22\n"
374                  << "%57 = OpBitcast %6 %56\n"
375                  << "%58 = OpLoad %6 %8\n"
376                  << "%59 = OpBitwiseOr %6 %58 %57\n"
377                  << "OpStore %8 %59\n"
378                  << "%63 = OpLoad %6 %8\n"
379                  << "%64 = OpConvertUToF %60 %63\n"
380                  << "OpStore %62 %64\n"
381                  << "%76 = OpAccessChain %75 %74 %22 %22\n"
382                  << "%77 = OpLoad %65 %76\n"
383                  << "%79 = OpAccessChain %78 %70 %22\n"
384                  << "OpStore %79 %77\n"
385                  << (*caseDef.geometryPointSizeSupported ? "%82 = OpAccessChain %80 %74 %22 %21\n"
386                                                            "%83 = OpLoad %60 %82\n"
387                                                            "%84 = OpAccessChain %81 %70 %21\n"
388                                                            "OpStore %84 %83\n" :
389                                                            "")
390                  << "OpEmitVertex\n"
391                  << "OpEndPrimitive\n"
392                  << "OpReturn\n"
393                  << "OpFunctionEnd\n";
394         programCollection.spirvAsmSources.add("geometry") << geometry.str() << buildOptionsSpr;
395     }
396     else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
397     {
398         /*
399             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
400             "layout(vertices = 2) out;\n"
401             "layout(location = 0) out float out_color[];\n"
402             "layout(set = 0, binding = 0) uniform Buffer1\n"
403             "{\n"
404             "  uint data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
405             "};\n"
406             "\n"
407             "void main (void)\n"
408             "{\n"
409             "  if (gl_InvocationID == 0)\n"
410               {\n"
411             "    gl_TessLevelOuter[0] = 1.0f;\n"
412             "    gl_TessLevelOuter[1] = 1.0f;\n"
413             "  }\n"
414             "  uint tempResult = 0;\n"
415             "  tempResult |= !bool(uvec4(0) == subgroupBallot(true)) ? 0x1 : 0;\n"
416             "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
417             "  tempResult |= !bool(uvec4(0) == subgroupBallot(bData)) ? 0x2 : 0;\n"
418             "  tempResult |= uvec4(0) == subgroupBallot(false) ? 0x4 : 0;\n"
419             "  out_color[gl_InvocationID] = float(tempResult);\n"
420             "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
421             "}\n";
422         */
423         const string controlSource = "; SPIR-V\n"
424                                      "; Version: 1.3\n"
425                                      "; Generator: Khronos Glslang Reference Front End; 2\n"
426                                      "; Bound: 102\n"
427                                      "; Schema: 0\n"
428                                      "OpCapability Tessellation\n"
429                                      "OpCapability GroupNonUniform\n" +
430                                      capabilityBallotHeader + extensionHeader +
431                                      "%1 = OpExtInstImport \"GLSL.std.450\"\n"
432                                      "OpMemoryModel Logical GLSL450\n"
433                                      "OpEntryPoint TessellationControl %4 \"main\" %8 %20 %50 %78 %89 %95\n"
434                                      "OpExecutionMode %4 OutputVertices 2\n"
435                                      "OpDecorate %8 BuiltIn InvocationId\n"
436                                      "OpDecorate %20 Patch\n"
437                                      "OpDecorate %20 BuiltIn TessLevelOuter\n"
438                                      "OpDecorate %45 ArrayStride 16\n"
439                                      "OpMemberDecorate %46 0 Offset 0\n"
440                                      "OpDecorate %46 Block\n"
441                                      "OpDecorate %48 DescriptorSet 0\n"
442                                      "OpDecorate %48 Binding 0\n"
443                                      "OpDecorate %50 RelaxedPrecision\n"
444                                      "OpDecorate %50 BuiltIn SubgroupLocalInvocationId\n"
445                                      "OpDecorate %51 RelaxedPrecision\n"
446                                      "OpDecorate %78 Location 0\n"
447                                      "OpMemberDecorate %86 0 BuiltIn Position\n"
448                                      "OpMemberDecorate %86 1 BuiltIn PointSize\n"
449                                      "OpMemberDecorate %86 2 BuiltIn ClipDistance\n"
450                                      "OpMemberDecorate %86 3 BuiltIn CullDistance\n"
451                                      "OpDecorate %86 Block\n"
452                                      "OpMemberDecorate %91 0 BuiltIn Position\n"
453                                      "OpMemberDecorate %91 1 BuiltIn PointSize\n"
454                                      "OpMemberDecorate %91 2 BuiltIn ClipDistance\n"
455                                      "OpMemberDecorate %91 3 BuiltIn CullDistance\n"
456                                      "OpDecorate %91 Block\n"
457                                      "%2 = OpTypeVoid\n"
458                                      "%3 = OpTypeFunction %2\n"
459                                      "%6 = OpTypeInt 32 1\n"
460                                      "%7 = OpTypePointer Input %6\n"
461                                      "%8 = OpVariable %7 Input\n"
462                                      "%10 = OpConstant %6 0\n"
463                                      "%11 = OpTypeBool\n"
464                                      "%15 = OpTypeFloat 32\n"
465                                      "%16 = OpTypeInt 32 0\n"
466                                      "%17 = OpConstant %16 4\n"
467                                      "%18 = OpTypeArray %15 %17\n"
468                                      "%19 = OpTypePointer Output %18\n"
469                                      "%20 = OpVariable %19 Output\n"
470                                      "%21 = OpConstant %15 1\n"
471                                      "%22 = OpTypePointer Output %15\n"
472                                      "%24 = OpConstant %6 1\n"
473                                      "%26 = OpTypePointer Function %16\n"
474                                      "%28 = OpConstant %16 0\n"
475                                      "%29 = OpTypeVector %16 4\n"
476                                      "%30 = OpConstantComposite %29 %28 %28 %28 %28\n"
477                                      "%31 = OpConstantTrue %11\n"
478                                      "%32 = OpConstant %16 3\n"
479                                      "%34 = OpTypeVector %11 4\n"
480                                      "%42 = OpTypePointer Function %11\n"
481                                      "%44 = OpConstant %16 " +
482                                      subgroupSizeStr +
483                                      "\n"
484                                      "%45 = OpTypeArray %16 %44\n"
485                                      "%46 = OpTypeStruct %45\n"
486                                      "%47 = OpTypePointer Uniform %46\n"
487                                      "%48 = OpVariable %47 Uniform\n"
488                                      "%49 = OpTypePointer Input %16\n"
489                                      "%50 = OpVariable %49 Input\n"
490                                      "%52 = OpTypePointer Uniform %16\n"
491                                      "%61 = OpConstant %6 2\n"
492                                      "%66 = OpConstantFalse %11\n"
493                                      "%70 = OpConstant %6 4\n"
494                                      "%75 = OpConstant %16 2\n"
495                                      "%76 = OpTypeArray %15 %75\n"
496                                      "%77 = OpTypePointer Output %76\n"
497                                      "%78 = OpVariable %77 Output\n"
498                                      "%83 = OpTypeVector %15 4\n"
499                                      "%84 = OpConstant %16 1\n"
500                                      "%85 = OpTypeArray %15 %84\n"
501                                      "%86 = OpTypeStruct %83 %15 %85 %85\n"
502                                      "%87 = OpTypeArray %86 %75\n"
503                                      "%88 = OpTypePointer Output %87\n"
504                                      "%89 = OpVariable %88 Output\n"
505                                      "%91 = OpTypeStruct %83 %15 %85 %85\n"
506                                      "%92 = OpConstant %16 32\n"
507                                      "%93 = OpTypeArray %91 %92\n"
508                                      "%94 = OpTypePointer Input %93\n"
509                                      "%95 = OpVariable %94 Input\n"
510                                      "%97 = OpTypePointer Input %83\n"
511                                      "%100 = OpTypePointer Output %83\n"
512                                      "%4 = OpFunction %2 None %3\n"
513                                      "%5 = OpLabel\n"
514                                      "%27 = OpVariable %26 Function\n"
515                                      "%43 = OpVariable %42 Function\n"
516                                      "%9 = OpLoad %6 %8\n"
517                                      "%12 = OpIEqual %11 %9 %10\n"
518                                      "OpSelectionMerge %14 None\n"
519                                      "OpBranchConditional %12 %13 %14\n"
520                                      "%13 = OpLabel\n"
521                                      "%23 = OpAccessChain %22 %20 %10\n"
522                                      "OpStore %23 %21\n"
523                                      "%25 = OpAccessChain %22 %20 %24\n"
524                                      "OpStore %25 %21\n"
525                                      "OpBranch %14\n"
526                                      "%14 = OpLabel\n"
527                                      "OpStore %27 %28\n"
528                                      "%33 = " +
529                                      (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %29 %31" :
530                                                                              "OpGroupNonUniformBallot %29 %32 %31") +
531                                      "\n"
532                                      "%35 = OpIEqual %34 %30 %33\n"
533                                      "%36 = OpAll %11 %35\n"
534                                      "%37 = OpLogicalNot %11 %36\n"
535                                      "%38 = OpSelect %6 %37 %24 %10\n"
536                                      "%39 = OpBitcast %16 %38\n"
537                                      "%40 = OpLoad %16 %27\n"
538                                      "%41 = OpBitwiseOr %16 %40 %39\n"
539                                      "OpStore %27 %41\n"
540                                      "%51 = OpLoad %16 %50\n"
541                                      "%53 = OpAccessChain %52 %48 %10 %51\n"
542                                      "%54 = OpLoad %16 %53\n"
543                                      "%55 = OpINotEqual %11 %54 %28\n"
544                                      "OpStore %43 %55\n"
545                                      "%56 = OpLoad %11 %43\n"
546                                      "%57 = " +
547                                      (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %29 %56" :
548                                                                              "OpGroupNonUniformBallot %29 %32 %56") +
549                                      "\n"
550                                      "%58 = OpIEqual %34 %30 %57\n"
551                                      "%59 = OpAll %11 %58\n"
552                                      "%60 = OpLogicalNot %11 %59\n"
553                                      "%62 = OpSelect %6 %60 %61 %10\n"
554                                      "%63 = OpBitcast %16 %62\n"
555                                      "%64 = OpLoad %16 %27\n"
556                                      "%65 = OpBitwiseOr %16 %64 %63\n"
557                                      "OpStore %27 %65\n"
558                                      "%67 = " +
559                                      (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %29 %66" :
560                                                                              "OpGroupNonUniformBallot %29 %32 %66") +
561                                      "\n"
562                                      "%68 = OpIEqual %34 %30 %67\n"
563                                      "%69 = OpAll %11 %68\n"
564                                      "%71 = OpSelect %6 %69 %70 %10\n"
565                                      "%72 = OpBitcast %16 %71\n"
566                                      "%73 = OpLoad %16 %27\n"
567                                      "%74 = OpBitwiseOr %16 %73 %72\n"
568                                      "OpStore %27 %74\n"
569                                      "%79 = OpLoad %6 %8\n"
570                                      "%80 = OpLoad %16 %27\n"
571                                      "%81 = OpConvertUToF %15 %80\n"
572                                      "%82 = OpAccessChain %22 %78 %79\n"
573                                      "OpStore %82 %81\n"
574                                      "%90 = OpLoad %6 %8\n"
575                                      "%96 = OpLoad %6 %8\n"
576                                      "%98 = OpAccessChain %97 %95 %96 %10\n"
577                                      "%99 = OpLoad %83 %98\n"
578                                      "%101 = OpAccessChain %100 %89 %90 %10\n"
579                                      "OpStore %101 %99\n"
580                                      "OpReturn\n"
581                                      "OpFunctionEnd\n";
582 
583         programCollection.spirvAsmSources.add("tesc") << controlSource << buildOptionsSpr;
584         subgroups::setTesEvalShaderFrameBuffer(programCollection);
585     }
586     else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
587     {
588         /*
589             "#extension GL_KHR_shader_subgroup_ballot: enable\n"
590             "layout(isolines, equal_spacing, ccw ) in;\n"
591             "layout(location = 0) out float out_color;\n"
592             "layout(set = 0, binding = 0) uniform Buffer1\n"
593             "{\n"
594             "  uint data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
595             "};\n"
596             "\n"
597             "void main (void)\n"
598             "{\n"
599             "  uint tempResult = 0;\n"
600             "  tempResult |= !bool(uvec4(0) == subgroupBallot(true)) ? 0x1 : 0;\n"
601             "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
602             "  tempResult |= !bool(uvec4(0) == subgroupBallot(bData)) ? 0x2 : 0;\n"
603             "  tempResult |= uvec4(0) == subgroupBallot(false) ? 0x4 : 0;\n"
604             "  out_color = float(tempResult);\n"
605             "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
606             "}\n";
607         */
608         const string evaluationSource = "; SPIR-V\n"
609                                         "; Version: 1.3\n"
610                                         "; Generator: Khronos Glslang Reference Front End; 2\n"
611                                         "; Bound: 91\n"
612                                         "; Schema: 0\n"
613                                         "OpCapability Tessellation\n"
614                                         "OpCapability GroupNonUniform\n" +
615                                         capabilityBallotHeader + extensionHeader +
616                                         "%1 = OpExtInstImport \"GLSL.std.450\"\n"
617                                         "OpMemoryModel Logical GLSL450\n"
618                                         "OpEntryPoint TessellationEvaluation %4 \"main\" %35 %62 %70 %75 %83\n"
619                                         "OpExecutionMode %4 Isolines\n"
620                                         "OpExecutionMode %4 SpacingEqual\n"
621                                         "OpExecutionMode %4 VertexOrderCcw\n"
622                                         "OpDecorate %30 ArrayStride 16\n"
623                                         "OpMemberDecorate %31 0 Offset 0\n"
624                                         "OpDecorate %31 Block\n"
625                                         "OpDecorate %33 DescriptorSet 0\n"
626                                         "OpDecorate %33 Binding 0\n"
627                                         "OpDecorate %35 RelaxedPrecision\n"
628                                         "OpDecorate %35 BuiltIn SubgroupLocalInvocationId\n"
629                                         "OpDecorate %36 RelaxedPrecision\n"
630                                         "OpDecorate %62 Location 0\n"
631                                         "OpMemberDecorate %68 0 BuiltIn Position\n"
632                                         "OpMemberDecorate %68 1 BuiltIn PointSize\n"
633                                         "OpMemberDecorate %68 2 BuiltIn ClipDistance\n"
634                                         "OpMemberDecorate %68 3 BuiltIn CullDistance\n"
635                                         "OpDecorate %68 Block\n"
636                                         "OpMemberDecorate %71 0 BuiltIn Position\n"
637                                         "OpMemberDecorate %71 1 BuiltIn PointSize\n"
638                                         "OpMemberDecorate %71 2 BuiltIn ClipDistance\n"
639                                         "OpMemberDecorate %71 3 BuiltIn CullDistance\n"
640                                         "OpDecorate %71 Block\n"
641                                         "OpDecorate %83 BuiltIn TessCoord\n"
642                                         "%2 = OpTypeVoid\n"
643                                         "%3 = OpTypeFunction %2\n"
644                                         "%6 = OpTypeInt 32 0\n"
645                                         "%7 = OpTypePointer Function %6\n"
646                                         "%9 = OpConstant %6 0\n"
647                                         "%10 = OpTypeVector %6 4\n"
648                                         "%11 = OpConstantComposite %10 %9 %9 %9 %9\n"
649                                         "%12 = OpTypeBool\n"
650                                         "%13 = OpConstantTrue %12\n"
651                                         "%14 = OpConstant %6 3\n"
652                                         "%16 = OpTypeVector %12 4\n"
653                                         "%20 = OpTypeInt 32 1\n"
654                                         "%21 = OpConstant %20 1\n"
655                                         "%22 = OpConstant %20 0\n"
656                                         "%27 = OpTypePointer Function %12\n"
657                                         "%29 = OpConstant %6 " +
658                                         subgroupSizeStr +
659                                         "\n"
660                                         "%30 = OpTypeArray %6 %29\n"
661                                         "%31 = OpTypeStruct %30\n"
662                                         "%32 = OpTypePointer Uniform %31\n"
663                                         "%33 = OpVariable %32 Uniform\n"
664                                         "%34 = OpTypePointer Input %6\n"
665                                         "%35 = OpVariable %34 Input\n"
666                                         "%37 = OpTypePointer Uniform %6\n"
667                                         "%46 = OpConstant %20 2\n"
668                                         "%51 = OpConstantFalse %12\n"
669                                         "%55 = OpConstant %20 4\n"
670                                         "%60 = OpTypeFloat 32\n"
671                                         "%61 = OpTypePointer Output %60\n"
672                                         "%62 = OpVariable %61 Output\n"
673                                         "%65 = OpTypeVector %60 4\n"
674                                         "%66 = OpConstant %6 1\n"
675                                         "%67 = OpTypeArray %60 %66\n"
676                                         "%68 = OpTypeStruct %65 %60 %67 %67\n"
677                                         "%69 = OpTypePointer Output %68\n"
678                                         "%70 = OpVariable %69 Output\n"
679                                         "%71 = OpTypeStruct %65 %60 %67 %67\n"
680                                         "%72 = OpConstant %6 32\n"
681                                         "%73 = OpTypeArray %71 %72\n"
682                                         "%74 = OpTypePointer Input %73\n"
683                                         "%75 = OpVariable %74 Input\n"
684                                         "%76 = OpTypePointer Input %65\n"
685                                         "%81 = OpTypeVector %60 3\n"
686                                         "%82 = OpTypePointer Input %81\n"
687                                         "%83 = OpVariable %82 Input\n"
688                                         "%84 = OpTypePointer Input %60\n"
689                                         "%89 = OpTypePointer Output %65\n"
690                                         "%4 = OpFunction %2 None %3\n"
691                                         "%5 = OpLabel\n"
692                                         "%8 = OpVariable %7 Function\n"
693                                         "%28 = OpVariable %27 Function\n"
694                                         "OpStore %8 %9\n"
695                                         "%15 = " +
696                                         (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %13" :
697                                                                                 "OpGroupNonUniformBallot %10 %14 %13") +
698                                         "\n"
699                                         "%17 = OpIEqual %16 %11 %15\n"
700                                         "%18 = OpAll %12 %17\n"
701                                         "%19 = OpLogicalNot %12 %18\n"
702                                         "%23 = OpSelect %20 %19 %21 %22\n"
703                                         "%24 = OpBitcast %6 %23\n"
704                                         "%25 = OpLoad %6 %8\n"
705                                         "%26 = OpBitwiseOr %6 %25 %24\n"
706                                         "OpStore %8 %26\n"
707                                         "%36 = OpLoad %6 %35\n"
708                                         "%38 = OpAccessChain %37 %33 %22 %36\n"
709                                         "%39 = OpLoad %6 %38\n"
710                                         "%40 = OpINotEqual %12 %39 %9\n"
711                                         "OpStore %28 %40\n"
712                                         "%41 = OpLoad %12 %28\n"
713                                         "%42 = " +
714                                         (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %41" :
715                                                                                 "OpGroupNonUniformBallot %10 %14 %41") +
716                                         "\n"
717                                         "%43 = OpIEqual %16 %11 %42\n"
718                                         "%44 = OpAll %12 %43\n"
719                                         "%45 = OpLogicalNot %12 %44\n"
720                                         "%47 = OpSelect %20 %45 %46 %22\n"
721                                         "%48 = OpBitcast %6 %47\n"
722                                         "%49 = OpLoad %6 %8\n"
723                                         "%50 = OpBitwiseOr %6 %49 %48\n"
724                                         "OpStore %8 %50\n"
725                                         "%52 = " +
726                                         (caseDef.extShaderSubGroupBallotTests ? "OpSubgroupBallotKHR %10 %51" :
727                                                                                 "OpGroupNonUniformBallot %10 %14 %51") +
728                                         "\n"
729                                         "%53 = OpIEqual %16 %11 %52\n"
730                                         "%54 = OpAll %12 %53\n"
731                                         "%56 = OpSelect %20 %54 %55 %22\n"
732                                         "%57 = OpBitcast %6 %56\n"
733                                         "%58 = OpLoad %6 %8\n"
734                                         "%59 = OpBitwiseOr %6 %58 %57\n"
735                                         "OpStore %8 %59\n"
736                                         "%63 = OpLoad %6 %8\n"
737                                         "%64 = OpConvertUToF %60 %63\n"
738                                         "OpStore %62 %64\n"
739                                         "%77 = OpAccessChain %76 %75 %22 %22\n"
740                                         "%78 = OpLoad %65 %77\n"
741                                         "%79 = OpAccessChain %76 %75 %21 %22\n"
742                                         "%80 = OpLoad %65 %79\n"
743                                         "%85 = OpAccessChain %84 %83 %9\n"
744                                         "%86 = OpLoad %60 %85\n"
745                                         "%87 = OpCompositeConstruct %65 %86 %86 %86 %86\n"
746                                         "%88 = OpExtInst %65 %1 FMix %78 %80 %87\n"
747                                         "%90 = OpAccessChain %89 %70 %22\n"
748                                         "OpStore %90 %88\n"
749                                         "OpReturn\n"
750                                         "OpFunctionEnd\n";
751         subgroups::setTesCtrlShaderFrameBuffer(programCollection);
752         programCollection.spirvAsmSources.add("tese") << evaluationSource << buildOptionsSpr;
753     }
754     else
755     {
756         DE_FATAL("Unsupported shader stage");
757     }
758 }
759 
getExtHeader(const CaseDefinition & caseDef)760 string getExtHeader(const CaseDefinition &caseDef)
761 {
762     return (caseDef.extShaderSubGroupBallotTests ? "#extension GL_ARB_shader_ballot: enable\n"
763                                                    "#extension GL_ARB_gpu_shader_int64: enable\n"
764                                                    "#extension GL_KHR_shader_subgroup_basic: enable\n" :
765                                                    "#extension GL_KHR_shader_subgroup_ballot: enable\n");
766 }
767 
getBodySource(const CaseDefinition & caseDef)768 string getBodySource(const CaseDefinition &caseDef)
769 {
770     const string cmpStr =
771         caseDef.extShaderSubGroupBallotTests ? "uint64_t(0) == ballotARB" : "uvec4(0) == subgroupBallot";
772 
773     if (isAllComputeStages(caseDef.shaderStage))
774     {
775         const string cmpStrB = caseDef.extShaderSubGroupBallotTests ? "ballotARB" : "subgroupBallot";
776 
777         return "  uint tempResult = 0;\n"
778                "  tempResult |= sharedMemoryBallot(true) == " +
779                cmpStrB +
780                "(true) ? 0x1 : 0;\n"
781                "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
782                "  tempResult |= sharedMemoryBallot(bData) == " +
783                cmpStrB +
784                "(bData) ? 0x2 : 0;\n"
785                "  tempResult |= " +
786                cmpStr +
787                "(false) ? 0x4 : 0;\n"
788                "  tempRes = tempResult;\n";
789     }
790     else
791     {
792         return "  uint tempResult = 0;\n"
793                "  tempResult |= !bool(" +
794                cmpStr +
795                "(true)) ? 0x1 : 0;\n"
796                "  bool bData = data[gl_SubgroupInvocationID] != 0;\n"
797                "  tempResult |= !bool(" +
798                cmpStr +
799                "(bData)) ? 0x2 : 0;\n"
800                "  tempResult |= " +
801                cmpStr +
802                "(false) ? 0x4 : 0;\n"
803                "  tempRes = tempResult;\n";
804     }
805 }
806 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)807 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
808 {
809 #ifndef CTS_USES_VULKANSC
810     const bool spirv14required =
811         (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
812 #else
813     const bool spirv14required = false;
814 #endif // CTS_USES_VULKANSC
815     const SpirvVersion spirvVersion = (spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3);
816     const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
817     const string extHeader      = getExtHeader(caseDef);
818     const string testSrc        = getBodySource(caseDef);
819     const string testHelper     = !isAllComputeStages(caseDef.shaderStage) ? "" :
820                                   caseDef.extShaderSubGroupBallotTests ? subgroups::getSharedMemoryBallotHelperARB() :
821                                                                          subgroups::getSharedMemoryBallotHelper();
822     const bool pointSizeSupport = *caseDef.geometryPointSizeSupported;
823 
824     subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, VK_FORMAT_R32_UINT,
825                                pointSizeSupport, extHeader, testSrc, testHelper);
826 }
827 
supportedCheck(Context & context,CaseDefinition caseDef)828 void supportedCheck(Context &context, CaseDefinition caseDef)
829 {
830     if (!subgroups::isSubgroupSupported(context))
831         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
832 
833     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
834     {
835         TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
836     }
837 
838     if (caseDef.extShaderSubGroupBallotTests && !context.requireDeviceFunctionality("VK_EXT_shader_subgroup_ballot"))
839     {
840         TCU_THROW(NotSupportedError, "Device does not support VK_EXT_shader_subgroup_ballot extension");
841     }
842 
843     if (caseDef.extShaderSubGroupBallotTests && !subgroups::isInt64SupportedForDevice(context))
844     {
845         TCU_THROW(NotSupportedError, "Device does not support int64 data types");
846     }
847 
848     if (caseDef.requiredSubgroupSize)
849     {
850         context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
851 
852 #ifndef CTS_USES_VULKANSC
853         const VkPhysicalDeviceSubgroupSizeControlFeatures &subgroupSizeControlFeatures =
854             context.getSubgroupSizeControlFeatures();
855         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
856             context.getSubgroupSizeControlProperties();
857 #else
858         const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT &subgroupSizeControlFeatures =
859             context.getSubgroupSizeControlFeaturesEXT();
860         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
861             context.getSubgroupSizeControlPropertiesEXT();
862 #endif // CTS_USES_VULKANSC
863 
864         if (subgroupSizeControlFeatures.subgroupSizeControl == false)
865             TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
866 
867         if (subgroupSizeControlFeatures.computeFullSubgroups == false)
868             TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
869 
870         if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
871             TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
872     }
873 
874     *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
875 
876 #ifndef CTS_USES_VULKANSC
877     if (isAllRayTracingStages(caseDef.shaderStage))
878     {
879         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
880     }
881     else if (isAllMeshShadingStages(caseDef.shaderStage))
882     {
883         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
884         context.requireDeviceFunctionality("VK_EXT_mesh_shader");
885 
886         if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
887         {
888             const auto &features = context.getMeshShaderFeaturesEXT();
889             if (!features.taskShader)
890                 TCU_THROW(NotSupportedError, "Task shaders not supported");
891         }
892     }
893 #endif // CTS_USES_VULKANSC
894 
895     subgroups::supportedCheckShader(context, caseDef.shaderStage);
896 }
897 
noSSBOtest(Context & context,const CaseDefinition caseDef)898 TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
899 {
900     const subgroups::SSBOData inputData = {
901         subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
902         subgroups::SSBOData::LayoutStd140,      //  InputDataLayoutType layout;
903         VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
904         subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
905         subgroups::SSBOData::BindingUBO,        //  BindingType bindingType;
906     };
907 
908     switch (caseDef.shaderStage)
909     {
910     case VK_SHADER_STAGE_VERTEX_BIT:
911         return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
912                                                     checkVertexPipelineStages);
913     case VK_SHADER_STAGE_GEOMETRY_BIT:
914         return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
915                                                       checkVertexPipelineStages);
916     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
917         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
918                                                                     checkVertexPipelineStages, caseDef.shaderStage);
919     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
920         return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
921                                                                     checkVertexPipelineStages, caseDef.shaderStage);
922     default:
923         TCU_THROW(InternalError, "Unhandled shader stage");
924     }
925 }
926 
test(Context & context,const CaseDefinition caseDef)927 TestStatus test(Context &context, const CaseDefinition caseDef)
928 {
929     const bool isCompute = isAllComputeStages(caseDef.shaderStage);
930 #ifndef CTS_USES_VULKANSC
931     const bool isMesh = isAllMeshShadingStages(caseDef.shaderStage);
932 #else
933     const bool isMesh = false;
934 #endif // CTS_USES_VULKANSC
935     DE_ASSERT(!(isCompute && isMesh));
936 
937     if (isCompute || isMesh)
938     {
939 #ifndef CTS_USES_VULKANSC
940         const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
941             context.getSubgroupSizeControlProperties();
942 #else
943         const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
944             context.getSubgroupSizeControlPropertiesEXT();
945 #endif // CTS_USES_VULKANSC
946         TestLog &log                        = context.getTestContext().getLog();
947         const subgroups::SSBOData inputData = {
948             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
949             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
950             VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
951             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
952         };
953 
954         if (caseDef.requiredSubgroupSize == false)
955         {
956             if (isCompute)
957                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
958                                                   checkComputeOrMesh);
959             else
960                 return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, nullptr, checkComputeOrMesh);
961         }
962 
963         log << TestLog::Message << "Testing required subgroup size range ["
964             << subgroupSizeControlProperties.minSubgroupSize << ", " << subgroupSizeControlProperties.maxSubgroupSize
965             << "]" << TestLog::EndMessage;
966 
967         // According to the spec, requiredSubgroupSize must be a power-of-two integer.
968         for (uint32_t size = subgroupSizeControlProperties.minSubgroupSize;
969              size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
970         {
971             TestStatus result(QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
972 
973             if (isCompute)
974                 result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
975                                                     checkComputeOrMesh, size);
976             else
977                 result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
978                                                  checkComputeOrMesh, size);
979 
980             if (result.getCode() != QP_TEST_RESULT_PASS)
981             {
982                 log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
983                 return result;
984             }
985         }
986 
987         return TestStatus::pass("OK");
988     }
989     else if (isAllGraphicsStages(caseDef.shaderStage))
990     {
991         const VkShaderStageFlags stages = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
992         const subgroups::SSBOData inputData = {
993             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
994             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
995             VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
996             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
997             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
998             4u,                                     //  uint32_t binding;
999             stages,                                 //  vk::VkShaderStageFlags stages;
1000         };
1001 
1002         return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages,
1003                                     stages);
1004     }
1005 #ifndef CTS_USES_VULKANSC
1006     else if (isAllRayTracingStages(caseDef.shaderStage))
1007     {
1008         const VkShaderStageFlags stages = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
1009         const subgroups::SSBOData inputData = {
1010             subgroups::SSBOData::InitializeNonZero, //  InputDataInitializeType initializeType;
1011             subgroups::SSBOData::LayoutStd430,      //  InputDataLayoutType layout;
1012             VK_FORMAT_R32_UINT,                     //  vk::VkFormat format;
1013             subgroups::maxSupportedSubgroupSize(),  //  vk::VkDeviceSize numElements;
1014             subgroups::SSBOData::BindingSSBO,       //  bool isImage;
1015             6u,                                     //  uint32_t binding;
1016             stages,                                 //  vk::VkShaderStageFlags stages;
1017         };
1018 
1019         return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL,
1020                                               checkVertexPipelineStages, stages);
1021     }
1022 #endif // CTS_USES_VULKANSC
1023     else
1024         TCU_THROW(InternalError, "Unknown stage or invalid stage set");
1025 }
1026 } // namespace
1027 
1028 namespace vkt
1029 {
1030 namespace subgroups
1031 {
createSubgroupsBallotTests(TestContext & testCtx)1032 TestCaseGroup *createSubgroupsBallotTests(TestContext &testCtx)
1033 {
1034     de::MovePtr<TestCaseGroup> group(new TestCaseGroup(testCtx, "ballot"));
1035     de::MovePtr<TestCaseGroup> graphicGroup(new TestCaseGroup(testCtx, "graphics"));
1036     de::MovePtr<TestCaseGroup> computeGroup(new TestCaseGroup(testCtx, "compute"));
1037     de::MovePtr<TestCaseGroup> framebufferGroup(new TestCaseGroup(testCtx, "framebuffer"));
1038 #ifndef CTS_USES_VULKANSC
1039     de::MovePtr<TestCaseGroup> raytracingGroup(new TestCaseGroup(testCtx, "ray_tracing"));
1040     de::MovePtr<TestCaseGroup> meshGroup(new TestCaseGroup(testCtx, "mesh"));
1041     de::MovePtr<TestCaseGroup> meshGroupEXT(new TestCaseGroup(testCtx, "mesh"));
1042 #endif // CTS_USES_VULKANSC
1043     de::MovePtr<TestCaseGroup> groupEXT(new TestCaseGroup(testCtx, "ext_shader_subgroup_ballot"));
1044     de::MovePtr<TestCaseGroup> graphicGroupEXT(new TestCaseGroup(testCtx, "graphics"));
1045     de::MovePtr<TestCaseGroup> computeGroupEXT(new TestCaseGroup(testCtx, "compute"));
1046     de::MovePtr<TestCaseGroup> framebufferGroupEXT(new TestCaseGroup(testCtx, "framebuffer"));
1047     const VkShaderStageFlags fbStages[] = {
1048         VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
1049         VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
1050         VK_SHADER_STAGE_GEOMETRY_BIT,
1051         VK_SHADER_STAGE_VERTEX_BIT,
1052     };
1053 #ifndef CTS_USES_VULKANSC
1054     const VkShaderStageFlags meshStages[] = {
1055         VK_SHADER_STAGE_MESH_BIT_EXT,
1056         VK_SHADER_STAGE_TASK_BIT_EXT,
1057     };
1058 #endif // CTS_USES_VULKANSC
1059     const bool boolValues[] = {false, true};
1060 
1061     for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
1062     {
1063         const bool requiredSubgroupSize = boolValues[groupSizeNdx];
1064         const string testNameSuffix     = requiredSubgroupSize ? "_requiredsubgroupsize" : "";
1065 
1066         for (size_t extNdx = 0; extNdx < DE_LENGTH_OF_ARRAY(boolValues); ++extNdx)
1067         {
1068             const bool extShaderSubGroupBallotTests = boolValues[extNdx];
1069             TestCaseGroup *testGroup = extShaderSubGroupBallotTests ? computeGroupEXT.get() : computeGroup.get();
1070             {
1071                 const CaseDefinition caseDef = {
1072                     VK_SHADER_STAGE_COMPUTE_BIT,   //  VkShaderStageFlags shaderStage;
1073                     de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
1074                     extShaderSubGroupBallotTests,  //  bool extShaderSubGroupBallotTests;
1075                     requiredSubgroupSize,          //  bool requiredSubgroupSize;
1076                 };
1077                 const string testName = getShaderStageName(caseDef.shaderStage) + testNameSuffix;
1078 
1079                 addFunctionCaseWithPrograms(testGroup, testName, supportedCheck, initPrograms, test, caseDef);
1080             }
1081         }
1082     }
1083 
1084 #ifndef CTS_USES_VULKANSC
1085     for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
1086     {
1087         const bool requiredSubgroupSize = boolValues[groupSizeNdx];
1088         const string testNameSuffix     = requiredSubgroupSize ? "_requiredsubgroupsize" : "";
1089 
1090         for (size_t extNdx = 0; extNdx < DE_LENGTH_OF_ARRAY(boolValues); ++extNdx)
1091         {
1092             const bool extShaderSubGroupBallotTests = boolValues[extNdx];
1093             TestCaseGroup *testGroup = extShaderSubGroupBallotTests ? meshGroupEXT.get() : meshGroup.get();
1094 
1095             for (const auto &stage : meshStages)
1096             {
1097                 const CaseDefinition caseDef = {
1098                     stage,                         //  VkShaderStageFlags shaderStage;
1099                     de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
1100                     extShaderSubGroupBallotTests,  //  bool extShaderSubGroupBallotTests;
1101                     requiredSubgroupSize,          //  bool requiredSubgroupSize;
1102                 };
1103                 const string testName =
1104                     getShaderStageName(caseDef.shaderStage) + testNameSuffix + "_" + getShaderStageName(stage);
1105 
1106                 addFunctionCaseWithPrograms(testGroup, testName, supportedCheck, initPrograms, test, caseDef);
1107             }
1108         }
1109     }
1110 #endif // CTS_USES_VULKANSC
1111 
1112     for (size_t extNdx = 0; extNdx < DE_LENGTH_OF_ARRAY(boolValues); ++extNdx)
1113     {
1114         const bool extShaderSubGroupBallotTests = boolValues[extNdx];
1115         TestCaseGroup *testGroup     = extShaderSubGroupBallotTests ? graphicGroupEXT.get() : graphicGroup.get();
1116         const CaseDefinition caseDef = {
1117             VK_SHADER_STAGE_ALL_GRAPHICS,  //  VkShaderStageFlags shaderStage;
1118             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
1119             extShaderSubGroupBallotTests,  //  bool extShaderSubGroupBallotTests;
1120             false,                         //  bool requiredSubgroupSize;
1121         };
1122 
1123         addFunctionCaseWithPrograms(testGroup, "graphic", supportedCheck, initPrograms, test, caseDef);
1124     }
1125 
1126 #ifndef CTS_USES_VULKANSC
1127     {
1128         const CaseDefinition caseDef = {
1129             SHADER_STAGE_ALL_RAY_TRACING,  //  VkShaderStageFlags shaderStage;
1130             de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
1131             false,                         //  bool extShaderSubGroupBallotTests;
1132             false,                         //  bool requiredSubgroupSize;
1133         };
1134 
1135         addFunctionCaseWithPrograms(raytracingGroup.get(), "test", supportedCheck, initPrograms, test, caseDef);
1136     }
1137 #endif // CTS_USES_VULKANSC
1138 
1139     for (size_t extNdx = 0; extNdx < DE_LENGTH_OF_ARRAY(boolValues); ++extNdx)
1140     {
1141         const bool extShaderSubGroupBallotTests = boolValues[extNdx];
1142         TestCaseGroup *testGroup = extShaderSubGroupBallotTests ? framebufferGroupEXT.get() : framebufferGroup.get();
1143 
1144         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
1145         {
1146             const CaseDefinition caseDef = {
1147                 fbStages[stageIndex],          //  VkShaderStageFlags shaderStage;
1148                 de::SharedPtr<bool>(new bool), //  de::SharedPtr<bool> geometryPointSizeSupported;
1149                 extShaderSubGroupBallotTests,  //  bool extShaderSubGroupBallotTests;
1150                 false                          //  bool requiredSubgroupSize;
1151             };
1152 
1153             addFunctionCaseWithPrograms(testGroup, getShaderStageName(caseDef.shaderStage), supportedCheck,
1154                                         initFrameBufferPrograms, noSSBOtest, caseDef);
1155         }
1156     }
1157 
1158     groupEXT->addChild(graphicGroupEXT.release());
1159     groupEXT->addChild(computeGroupEXT.release());
1160     groupEXT->addChild(framebufferGroupEXT.release());
1161 
1162     group->addChild(graphicGroup.release());
1163     group->addChild(computeGroup.release());
1164     group->addChild(framebufferGroup.release());
1165 #ifndef CTS_USES_VULKANSC
1166     group->addChild(raytracingGroup.release());
1167     group->addChild(meshGroup.release());
1168     groupEXT->addChild(meshGroupEXT.release());
1169 #endif // CTS_USES_VULKANSC
1170     group->addChild(groupEXT.release());
1171 
1172     return group.release();
1173 }
1174 
1175 } // namespace subgroups
1176 } // namespace vkt
1177