1 /*------------------------------------------------------------------------
2 * OpenGL Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017-2019 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 * Copyright (c) 2019 NVIDIA Corporation.
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 */ /*!
22 * \file
23 * \brief Subgroups Tests
24 */ /*--------------------------------------------------------------------*/
25
26 #include "glcSubgroupsArithmeticTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
// Subgroup arithmetic operation exercised by a test case.
// The enumerators double as plain integer indices: CaseDefinition::opType is an
// int and the registration loop walks 0 .. OPTYPE_LAST - 1, so the values below
// must stay dense and zero-based.
enum OpType
{
    // Reductions over the whole subgroup (subgroupAdd, subgroupMul, ...).
    OPTYPE_ADD = 0,
    OPTYPE_MUL = 1,
    OPTYPE_MIN = 2,
    OPTYPE_MAX = 3,
    OPTYPE_AND = 4,
    OPTYPE_OR  = 5,
    OPTYPE_XOR = 6,

    // Inclusive scans (subgroupInclusiveAdd, ...).
    OPTYPE_INCLUSIVE_ADD = 7,
    OPTYPE_INCLUSIVE_MUL = 8,
    OPTYPE_INCLUSIVE_MIN = 9,
    OPTYPE_INCLUSIVE_MAX = 10,
    OPTYPE_INCLUSIVE_AND = 11,
    OPTYPE_INCLUSIVE_OR  = 12,
    OPTYPE_INCLUSIVE_XOR = 13,

    // Exclusive scans (subgroupExclusiveAdd, ...).
    OPTYPE_EXCLUSIVE_ADD = 14,
    OPTYPE_EXCLUSIVE_MUL = 15,
    OPTYPE_EXCLUSIVE_MIN = 16,
    OPTYPE_EXCLUSIVE_MAX = 17,
    OPTYPE_EXCLUSIVE_AND = 18,
    OPTYPE_EXCLUSIVE_OR  = 19,
    OPTYPE_EXCLUSIVE_XOR = 20,

    // Number of operations; not itself a valid op type.
    OPTYPE_LAST = 21
};
66
checkVertexPipelineStages(std::vector<const void * > datas,uint32_t width,uint32_t)67 static bool checkVertexPipelineStages(std::vector<const void *> datas, uint32_t width, uint32_t)
68 {
69 return glc::subgroups::check(datas, width, 0x3);
70 }
71
checkComputeStage(std::vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)72 static bool checkComputeStage(std::vector<const void *> datas, const uint32_t numWorkgroups[3],
73 const uint32_t localSize[3], uint32_t)
74 {
75 return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x3);
76 }
77
getOpTypeName(int opType)78 std::string getOpTypeName(int opType)
79 {
80 switch (opType)
81 {
82 default:
83 DE_FATAL("Unsupported op type");
84 return "";
85 case OPTYPE_ADD:
86 return "subgroupAdd";
87 case OPTYPE_MUL:
88 return "subgroupMul";
89 case OPTYPE_MIN:
90 return "subgroupMin";
91 case OPTYPE_MAX:
92 return "subgroupMax";
93 case OPTYPE_AND:
94 return "subgroupAnd";
95 case OPTYPE_OR:
96 return "subgroupOr";
97 case OPTYPE_XOR:
98 return "subgroupXor";
99 case OPTYPE_INCLUSIVE_ADD:
100 return "subgroupInclusiveAdd";
101 case OPTYPE_INCLUSIVE_MUL:
102 return "subgroupInclusiveMul";
103 case OPTYPE_INCLUSIVE_MIN:
104 return "subgroupInclusiveMin";
105 case OPTYPE_INCLUSIVE_MAX:
106 return "subgroupInclusiveMax";
107 case OPTYPE_INCLUSIVE_AND:
108 return "subgroupInclusiveAnd";
109 case OPTYPE_INCLUSIVE_OR:
110 return "subgroupInclusiveOr";
111 case OPTYPE_INCLUSIVE_XOR:
112 return "subgroupInclusiveXor";
113 case OPTYPE_EXCLUSIVE_ADD:
114 return "subgroupExclusiveAdd";
115 case OPTYPE_EXCLUSIVE_MUL:
116 return "subgroupExclusiveMul";
117 case OPTYPE_EXCLUSIVE_MIN:
118 return "subgroupExclusiveMin";
119 case OPTYPE_EXCLUSIVE_MAX:
120 return "subgroupExclusiveMax";
121 case OPTYPE_EXCLUSIVE_AND:
122 return "subgroupExclusiveAnd";
123 case OPTYPE_EXCLUSIVE_OR:
124 return "subgroupExclusiveOr";
125 case OPTYPE_EXCLUSIVE_XOR:
126 return "subgroupExclusiveXor";
127 }
128 }
129
getOpTypeOperation(int opType,Format format,std::string lhs,std::string rhs)130 std::string getOpTypeOperation(int opType, Format format, std::string lhs, std::string rhs)
131 {
132 switch (opType)
133 {
134 default:
135 DE_FATAL("Unsupported op type");
136 return "";
137 case OPTYPE_ADD:
138 case OPTYPE_INCLUSIVE_ADD:
139 case OPTYPE_EXCLUSIVE_ADD:
140 return lhs + " + " + rhs;
141 case OPTYPE_MUL:
142 case OPTYPE_INCLUSIVE_MUL:
143 case OPTYPE_EXCLUSIVE_MUL:
144 return lhs + " * " + rhs;
145 case OPTYPE_MIN:
146 case OPTYPE_INCLUSIVE_MIN:
147 case OPTYPE_EXCLUSIVE_MIN:
148 switch (format)
149 {
150 default:
151 return "min(" + lhs + ", " + rhs + ")";
152 case FORMAT_R32_SFLOAT:
153 case FORMAT_R64_SFLOAT:
154 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs +
155 ")))";
156 case FORMAT_R32G32_SFLOAT:
157 case FORMAT_R32G32B32_SFLOAT:
158 case FORMAT_R32G32B32A32_SFLOAT:
159 case FORMAT_R64G64_SFLOAT:
160 case FORMAT_R64G64B64_SFLOAT:
161 case FORMAT_R64G64B64A64_SFLOAT:
162 return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" +
163 lhs + "))";
164 }
165 case OPTYPE_MAX:
166 case OPTYPE_INCLUSIVE_MAX:
167 case OPTYPE_EXCLUSIVE_MAX:
168 switch (format)
169 {
170 default:
171 return "max(" + lhs + ", " + rhs + ")";
172 case FORMAT_R32_SFLOAT:
173 case FORMAT_R64_SFLOAT:
174 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs +
175 ")))";
176 case FORMAT_R32G32_SFLOAT:
177 case FORMAT_R32G32B32_SFLOAT:
178 case FORMAT_R32G32B32A32_SFLOAT:
179 case FORMAT_R64G64_SFLOAT:
180 case FORMAT_R64G64B64_SFLOAT:
181 case FORMAT_R64G64B64A64_SFLOAT:
182 return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" +
183 lhs + "))";
184 }
185 case OPTYPE_AND:
186 case OPTYPE_INCLUSIVE_AND:
187 case OPTYPE_EXCLUSIVE_AND:
188 switch (format)
189 {
190 default:
191 return lhs + " & " + rhs;
192 case FORMAT_R32_BOOL:
193 return lhs + " && " + rhs;
194 case FORMAT_R32G32_BOOL:
195 return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
196 case FORMAT_R32G32B32_BOOL:
197 return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs +
198 ".z)";
199 case FORMAT_R32G32B32A32_BOOL:
200 return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs +
201 ".z, " + lhs + ".w && " + rhs + ".w)";
202 }
203 case OPTYPE_OR:
204 case OPTYPE_INCLUSIVE_OR:
205 case OPTYPE_EXCLUSIVE_OR:
206 switch (format)
207 {
208 default:
209 return lhs + " | " + rhs;
210 case FORMAT_R32_BOOL:
211 return lhs + " || " + rhs;
212 case FORMAT_R32G32_BOOL:
213 return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
214 case FORMAT_R32G32B32_BOOL:
215 return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs +
216 ".z)";
217 case FORMAT_R32G32B32A32_BOOL:
218 return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs +
219 ".z, " + lhs + ".w || " + rhs + ".w)";
220 }
221 case OPTYPE_XOR:
222 case OPTYPE_INCLUSIVE_XOR:
223 case OPTYPE_EXCLUSIVE_XOR:
224 switch (format)
225 {
226 default:
227 return lhs + " ^ " + rhs;
228 case FORMAT_R32_BOOL:
229 return lhs + " ^^ " + rhs;
230 case FORMAT_R32G32_BOOL:
231 return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
232 case FORMAT_R32G32B32_BOOL:
233 return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs +
234 ".z)";
235 case FORMAT_R32G32B32A32_BOOL:
236 return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs +
237 ".z, " + lhs + ".w ^^ " + rhs + ".w)";
238 }
239 }
240 }
241
getIdentity(int opType,Format format)242 std::string getIdentity(int opType, Format format)
243 {
244 bool isFloat = false;
245 bool isInt = false;
246 bool isUnsigned = false;
247
248 switch (format)
249 {
250 default:
251 DE_FATAL("Unhandled format!");
252 break;
253 case FORMAT_R32_SINT:
254 case FORMAT_R32G32_SINT:
255 case FORMAT_R32G32B32_SINT:
256 case FORMAT_R32G32B32A32_SINT:
257 isInt = true;
258 break;
259 case FORMAT_R32_UINT:
260 case FORMAT_R32G32_UINT:
261 case FORMAT_R32G32B32_UINT:
262 case FORMAT_R32G32B32A32_UINT:
263 isUnsigned = true;
264 break;
265 case FORMAT_R32_SFLOAT:
266 case FORMAT_R32G32_SFLOAT:
267 case FORMAT_R32G32B32_SFLOAT:
268 case FORMAT_R32G32B32A32_SFLOAT:
269 case FORMAT_R64_SFLOAT:
270 case FORMAT_R64G64_SFLOAT:
271 case FORMAT_R64G64B64_SFLOAT:
272 case FORMAT_R64G64B64A64_SFLOAT:
273 isFloat = true;
274 break;
275 case FORMAT_R32_BOOL:
276 case FORMAT_R32G32_BOOL:
277 case FORMAT_R32G32B32_BOOL:
278 case FORMAT_R32G32B32A32_BOOL:
279 break; // bool types are not anything
280 }
281
282 switch (opType)
283 {
284 default:
285 DE_FATAL("Unsupported op type");
286 return "";
287 case OPTYPE_ADD:
288 case OPTYPE_INCLUSIVE_ADD:
289 case OPTYPE_EXCLUSIVE_ADD:
290 return subgroups::getFormatNameForGLSL(format) + "(0)";
291 case OPTYPE_MUL:
292 case OPTYPE_INCLUSIVE_MUL:
293 case OPTYPE_EXCLUSIVE_MUL:
294 return subgroups::getFormatNameForGLSL(format) + "(1)";
295 case OPTYPE_MIN:
296 case OPTYPE_INCLUSIVE_MIN:
297 case OPTYPE_EXCLUSIVE_MIN:
298 if (isFloat)
299 {
300 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
301 }
302 else if (isInt)
303 {
304 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
305 }
306 else if (isUnsigned)
307 {
308 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
309 }
310 else
311 {
312 DE_FATAL("Unhandled case");
313 return "";
314 }
315 case OPTYPE_MAX:
316 case OPTYPE_INCLUSIVE_MAX:
317 case OPTYPE_EXCLUSIVE_MAX:
318 if (isFloat)
319 {
320 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
321 }
322 else if (isInt)
323 {
324 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
325 }
326 else if (isUnsigned)
327 {
328 return subgroups::getFormatNameForGLSL(format) + "(0u)";
329 }
330 else
331 {
332 DE_FATAL("Unhandled case");
333 return "";
334 }
335 case OPTYPE_AND:
336 case OPTYPE_INCLUSIVE_AND:
337 case OPTYPE_EXCLUSIVE_AND:
338 return subgroups::getFormatNameForGLSL(format) + "(~0)";
339 case OPTYPE_OR:
340 case OPTYPE_INCLUSIVE_OR:
341 case OPTYPE_EXCLUSIVE_OR:
342 return subgroups::getFormatNameForGLSL(format) + "(0)";
343 case OPTYPE_XOR:
344 case OPTYPE_INCLUSIVE_XOR:
345 case OPTYPE_EXCLUSIVE_XOR:
346 return subgroups::getFormatNameForGLSL(format) + "(0)";
347 }
348 }
349
getCompare(int opType,Format format,std::string lhs,std::string rhs)350 std::string getCompare(int opType, Format format, std::string lhs, std::string rhs)
351 {
352 std::string formatName = subgroups::getFormatNameForGLSL(format);
353 switch (format)
354 {
355 default:
356 return "all(equal(" + lhs + ", " + rhs + "))";
357 case FORMAT_R32_BOOL:
358 case FORMAT_R32_UINT:
359 case FORMAT_R32_SINT:
360 return "(" + lhs + " == " + rhs + ")";
361 case FORMAT_R32_SFLOAT:
362 case FORMAT_R64_SFLOAT:
363 switch (opType)
364 {
365 default:
366 return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
367 case OPTYPE_MIN:
368 case OPTYPE_INCLUSIVE_MIN:
369 case OPTYPE_EXCLUSIVE_MIN:
370 case OPTYPE_MAX:
371 case OPTYPE_INCLUSIVE_MAX:
372 case OPTYPE_EXCLUSIVE_MAX:
373 return "(" + lhs + " == " + rhs + ")";
374 }
375 case FORMAT_R32G32_SFLOAT:
376 case FORMAT_R32G32B32_SFLOAT:
377 case FORMAT_R32G32B32A32_SFLOAT:
378 case FORMAT_R64G64_SFLOAT:
379 case FORMAT_R64G64B64_SFLOAT:
380 case FORMAT_R64G64B64A64_SFLOAT:
381 switch (opType)
382 {
383 default:
384 return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
385 case OPTYPE_MIN:
386 case OPTYPE_INCLUSIVE_MIN:
387 case OPTYPE_EXCLUSIVE_MIN:
388 case OPTYPE_MAX:
389 case OPTYPE_INCLUSIVE_MAX:
390 case OPTYPE_EXCLUSIVE_MAX:
391 return "all(equal(" + lhs + ", " + rhs + "))";
392 }
393 }
394 }
395
// Parameters of one generated test case.
// Instances are brace-initialized positionally at the registration sites as
// {opType, shaderStage, format}, so the member order is part of the interface.
struct CaseDefinition
{
    int opType;                   // an OpType value (OPTYPE_ADD .. OPTYPE_EXCLUSIVE_XOR)
    ShaderStageFlags shaderStage; // single stage, or a stage mask such as SHADER_STAGE_ALL_GRAPHICS
    Format format;                // type of the data[] operands fed to the subgroup built-in
};
402
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)403 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
404 {
405 std::string indexVars;
406 std::ostringstream bdy;
407
408 subgroups::setFragmentShaderFrameBuffer(programCollection);
409
410 if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
411 subgroups::setVertexShaderFrameBuffer(programCollection);
412
413 switch (caseDef.opType)
414 {
415 default:
416 indexVars = " uint start = 0u, end = gl_SubgroupSize;\n";
417 break;
418 case OPTYPE_INCLUSIVE_ADD:
419 case OPTYPE_INCLUSIVE_MUL:
420 case OPTYPE_INCLUSIVE_MIN:
421 case OPTYPE_INCLUSIVE_MAX:
422 case OPTYPE_INCLUSIVE_AND:
423 case OPTYPE_INCLUSIVE_OR:
424 case OPTYPE_INCLUSIVE_XOR:
425 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID + 1u;\n";
426 break;
427 case OPTYPE_EXCLUSIVE_ADD:
428 case OPTYPE_EXCLUSIVE_MUL:
429 case OPTYPE_EXCLUSIVE_MIN:
430 case OPTYPE_EXCLUSIVE_MAX:
431 case OPTYPE_EXCLUSIVE_AND:
432 case OPTYPE_EXCLUSIVE_OR:
433 case OPTYPE_EXCLUSIVE_XOR:
434 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID;\n";
435 break;
436 }
437
438 bdy << indexVars << " " << subgroups::getFormatNameForGLSL(caseDef.format)
439 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
440 << " uint tempResult = 0u;\n"
441 << " for (uint index = start; index < end; index++)\n"
442 << " {\n"
443 << " if (subgroupBallotBitExtract(mask, index))\n"
444 << " {\n"
445 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
446 << " }\n"
447 << " }\n"
448 << " tempResult = "
449 << getCompare(caseDef.opType, caseDef.format, "ref",
450 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])")
451 << " ? 0x1u : 0u;\n"
452 << " if (1u == (gl_SubgroupInvocationID % 2u))\n"
453 << " {\n"
454 << " mask = subgroupBallot(true);\n"
455 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
456 << " for (uint index = start; index < end; index++)\n"
457 << " {\n"
458 << " if (subgroupBallotBitExtract(mask, index))\n"
459 << " {\n"
460 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
461 << " }\n"
462 << " }\n"
463 << " tempResult |= "
464 << getCompare(caseDef.opType, caseDef.format, "ref",
465 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])")
466 << " ? 0x2u : 0u;\n"
467 << " }\n"
468 << " else\n"
469 << " {\n"
470 << " tempResult |= 0x2u;\n"
471 << " }\n";
472
473 if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
474 {
475 std::ostringstream vertexSrc;
476 vertexSrc << "${VERSION_DECL}\n"
477 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
478 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
479 << "layout(location = 0) in highp vec4 in_position;\n"
480 << "layout(location = 0) out float out_color;\n"
481 << "layout(binding = 0, std140) uniform Buffer0\n"
482 << "{\n"
483 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
484 << subgroups::maxSupportedSubgroupSize() << "];\n"
485 << "};\n"
486 << "\n"
487 << "void main (void)\n"
488 << "{\n"
489 << " uvec4 mask = subgroupBallot(true);\n"
490 << bdy.str() << " out_color = float(tempResult);\n"
491 << " gl_Position = in_position;\n"
492 << " gl_PointSize = 1.0f;\n"
493 << "}\n";
494 programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
495 }
496 else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
497 {
498 std::ostringstream geometry;
499
500 geometry << "${VERSION_DECL}\n"
501 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
502 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
503 << "layout(points) in;\n"
504 << "layout(points, max_vertices = 1) out;\n"
505 << "layout(location = 0) out float out_color;\n"
506 << "layout(binding = 0, std140) uniform Buffer0\n"
507 << "{\n"
508 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
509 << subgroups::maxSupportedSubgroupSize() << "];\n"
510 << "};\n"
511 << "\n"
512 << "void main (void)\n"
513 << "{\n"
514 << " uvec4 mask = subgroupBallot(true);\n"
515 << bdy.str() << " out_color = float(tempResult);\n"
516 << " gl_Position = gl_in[0].gl_Position;\n"
517 << " EmitVertex();\n"
518 << " EndPrimitive();\n"
519 << "}\n";
520
521 programCollection.add("geometry") << glu::GeometrySource(geometry.str());
522 }
523 else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
524 {
525 std::ostringstream controlSource;
526 controlSource << "${VERSION_DECL}\n"
527 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
528 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
529 << "layout(vertices = 2) out;\n"
530 << "layout(location = 0) out float out_color[];\n"
531 << "layout(binding = 0, std140) uniform Buffer0\n"
532 << "{\n"
533 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
534 << subgroups::maxSupportedSubgroupSize() << "];\n"
535 << "};\n"
536 << "\n"
537 << "void main (void)\n"
538 << "{\n"
539 << " if (gl_InvocationID == 0)\n"
540 << " {\n"
541 << " gl_TessLevelOuter[0] = 1.0f;\n"
542 << " gl_TessLevelOuter[1] = 1.0f;\n"
543 << " }\n"
544 << " uvec4 mask = subgroupBallot(true);\n"
545 << bdy.str() << " out_color[gl_InvocationID] = float(tempResult);"
546 << " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
547 << "}\n";
548
549 programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
550 subgroups::setTesEvalShaderFrameBuffer(programCollection);
551 }
552 else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
553 {
554
555 std::ostringstream evaluationSource;
556 evaluationSource << "${VERSION_DECL}\n"
557 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
558 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
559 << "layout(isolines, equal_spacing, ccw ) in;\n"
560 << "layout(location = 0) out float out_color;\n"
561 << "layout(binding = 0, std140) uniform Buffer0\n"
562 << "{\n"
563 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
564 << subgroups::maxSupportedSubgroupSize() << "];\n"
565 << "};\n"
566 << "\n"
567 << "void main (void)\n"
568 << "{\n"
569 << " uvec4 mask = subgroupBallot(true);\n"
570 << bdy.str() << " out_color = float(tempResult);\n"
571 << " gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
572 << "}\n";
573
574 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
575 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
576 }
577 else
578 {
579 DE_FATAL("Unsupported shader stage");
580 }
581 }
582
// Generates the shaders for the SSBO-based "compute" and "graphics" variants.
// Every shader embeds the same body `bdy`. Starting from the op's identity, it folds
// data[index] over the invocations in [start, end) whose ballot bit is set, then
// compares the result with the subgroup built-in and records the outcome in bit 0x1.
// Odd invocations repeat the check with a freshly taken ballot (bit 0x2); even
// invocations set bit 0x2 directly.
//
// Compute: input data at binding 1, per-invocation results at binding 0.
// Graphics: vert, tesc, tese, geometry and fragment sources are all generated.
//  - All of them read the input data from the read-only buffer at binding 4.
//  - The vertex-pipeline stages write results to buffers at bindings 0..3.
//  - The fragment stage writes its result to a uint output.
void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
{
    // Invocation range folded into the reference value: the whole subgroup for
    // reductions, [0, id] for inclusive scans and [0, id) for exclusive scans.
    std::string indexVars;
    switch (caseDef.opType)
    {
    default:
        indexVars = " uint start = 0u, end = gl_SubgroupSize;\n";
        break;
    case OPTYPE_INCLUSIVE_ADD:
    case OPTYPE_INCLUSIVE_MUL:
    case OPTYPE_INCLUSIVE_MIN:
    case OPTYPE_INCLUSIVE_MAX:
    case OPTYPE_INCLUSIVE_AND:
    case OPTYPE_INCLUSIVE_OR:
    case OPTYPE_INCLUSIVE_XOR:
        indexVars = " uint start = 0u, end = gl_SubgroupInvocationID + 1u;\n";
        break;
    case OPTYPE_EXCLUSIVE_ADD:
    case OPTYPE_EXCLUSIVE_MUL:
    case OPTYPE_EXCLUSIVE_MIN:
    case OPTYPE_EXCLUSIVE_MAX:
    case OPTYPE_EXCLUSIVE_AND:
    case OPTYPE_EXCLUSIVE_OR:
    case OPTYPE_EXCLUSIVE_XOR:
        indexVars = " uint start = 0u, end = gl_SubgroupInvocationID;\n";
        break;
    }

    // Shared shader body. It relies on each shader declaring `mask` (a uvec4 ballot)
    // and `data[]`, and leaves the pass mask (0x3 on success) in tempResult.
    const string bdy = indexVars + " " + subgroups::getFormatNameForGLSL(caseDef.format) +
                       " ref = " + getIdentity(caseDef.opType, caseDef.format) +
                       ";\n"
                       " uint tempResult = 0u;\n"
                       " for (uint index = start; index < end; index++)\n"
                       " {\n"
                       " if (subgroupBallotBitExtract(mask, index))\n"
                       " {\n"
                       " ref = " +
                       getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") +
                       ";\n"
                       " }\n"
                       " }\n"
                       " tempResult = " +
                       getCompare(caseDef.opType, caseDef.format, "ref",
                                  getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") +
                       " ? 0x1u : 0u;\n"
                       " if (1u == (gl_SubgroupInvocationID % 2u))\n"
                       " {\n"
                       " mask = subgroupBallot(true);\n"
                       " ref = " +
                       getIdentity(caseDef.opType, caseDef.format) +
                       ";\n"
                       " for (uint index = start; index < end; index++)\n"
                       " {\n"
                       " if (subgroupBallotBitExtract(mask, index))\n"
                       " {\n"
                       " ref = " +
                       getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") +
                       ";\n"
                       " }\n"
                       " }\n"
                       " tempResult |= " +
                       getCompare(caseDef.opType, caseDef.format, "ref",
                                  getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") +
                       " ? 0x2u : 0u;\n"
                       " }\n"
                       " else\n"
                       " {\n"
                       " tempResult |= 0x2u;\n"
                       " }\n";

    if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    {
        // Compute: one result per invocation, indexed by the linearized global invocation ID.
        std::ostringstream src;

        src << "${VERSION_DECL}\n"
            << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
            << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
            << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
            << "layout(binding = 0, std430) buffer Buffer0\n"
            << "{\n"
            << " uint result[];\n"
            << "};\n"
            << "layout(binding = 1, std430) buffer Buffer1\n"
            << "{\n"
            << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
            << "};\n"
            << "\n"
            << "void main (void)\n"
            << "{\n"
            << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
            << " highp uint offset = globalSize.x * ((globalSize.y * "
               "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
               "gl_GlobalInvocationID.x;\n"
            << " uvec4 mask = subgroupBallot(true);\n"
            << bdy << " result[offset] = tempResult;\n"
            << "}\n";

        programCollection.add("comp") << glu::ComputeSource(src.str());
    }
    else
    {
        // Vertex stage: result indexed by gl_VertexID, written to binding 0.
        {
            const std::string vertex =
                "${VERSION_DECL}\n"
                "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
                "#extension GL_KHR_shader_subgroup_ballot: enable\n"
                "layout(binding = 0, std430) buffer Buffer0\n"
                "{\n"
                " uint result[];\n"
                "} b0;\n"
                "layout(binding = 4, std430) readonly buffer Buffer4\n"
                "{\n"
                " " +
                subgroups::getFormatNameForGLSL(caseDef.format) +
                " data[];\n"
                "};\n"
                "\n"
                "void main (void)\n"
                "{\n"
                " uvec4 mask = subgroupBallot(true);\n" +
                bdy +
                " b0.result[gl_VertexID] = tempResult;\n"
                " float pixelSize = 2.0f/1024.0f;\n"
                " float pixelPosition = pixelSize/2.0f - 1.0f;\n"
                " gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
                " gl_PointSize = 1.0f;\n"
                "}\n";
            programCollection.add("vert") << glu::VertexSource(vertex);
        }

        // Tessellation control stage: result indexed by gl_PrimitiveID, written to binding 1.
        {
            const std::string tesc = "${VERSION_DECL}\n"
                                     "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
                                     "#extension GL_KHR_shader_subgroup_ballot: enable\n"
                                     "layout(vertices=1) out;\n"
                                     "layout(binding = 1, std430) buffer Buffer1\n"
                                     "{\n"
                                     " uint result[];\n"
                                     "} b1;\n"
                                     "layout(binding = 4, std430) readonly buffer Buffer4\n"
                                     "{\n"
                                     " " +
                                     subgroups::getFormatNameForGLSL(caseDef.format) +
                                     " data[];\n"
                                     "};\n"
                                     "\n"
                                     "void main (void)\n"
                                     "{\n"
                                     " uvec4 mask = subgroupBallot(true);\n" +
                                     bdy +
                                     " b1.result[gl_PrimitiveID] = tempResult;\n"
                                     " if (gl_InvocationID == 0)\n"
                                     " {\n"
                                     " gl_TessLevelOuter[0] = 1.0f;\n"
                                     " gl_TessLevelOuter[1] = 1.0f;\n"
                                     " }\n"
                                     " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
                                     "}\n";
            programCollection.add("tesc") << glu::TessellationControlSource(tesc);
        }

        // Tessellation evaluation stage: two results per primitive (one per isoline
        // endpoint, selected by gl_TessCoord.x), written to binding 2.
        {
            const std::string tese = "${VERSION_DECL}\n"
                                     "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
                                     "#extension GL_KHR_shader_subgroup_ballot: enable\n"
                                     "layout(isolines) in;\n"
                                     "layout(binding = 2, std430) buffer Buffer2\n"
                                     "{\n"
                                     " uint result[];\n"
                                     "} b2;\n"
                                     "layout(binding = 4, std430) readonly buffer Buffer4\n"
                                     "{\n"
                                     " " +
                                     subgroups::getFormatNameForGLSL(caseDef.format) +
                                     " data[];\n"
                                     "};\n"
                                     "\n"
                                     "void main (void)\n"
                                     "{\n"
                                     " uvec4 mask = subgroupBallot(true);\n" +
                                     bdy +
                                     " b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = tempResult;\n"
                                     " float pixelSize = 2.0f/1024.0f;\n"
                                     " gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
                                     "}\n";
            programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
        }

        // Geometry stage: result indexed by gl_PrimitiveIDIn, written to binding 3.
        // ${TOPOLOGY} is substituted by addGeometryShadersFromTemplate.
        {
            const std::string geometry =
                // version added by addGeometryShadersFromTemplate
                "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
                "#extension GL_KHR_shader_subgroup_ballot: enable\n"
                "layout(${TOPOLOGY}) in;\n"
                "layout(points, max_vertices = 1) out;\n"
                "layout(binding = 3, std430) buffer Buffer3\n"
                "{\n"
                " uint result[];\n"
                "} b3;\n"
                "layout(binding = 4, std430) readonly buffer Buffer4\n"
                "{\n"
                " " +
                subgroups::getFormatNameForGLSL(caseDef.format) +
                " data[];\n"
                "};\n"
                "\n"
                "void main (void)\n"
                "{\n"
                " uvec4 mask = subgroupBallot(true);\n" +
                bdy +
                " b3.result[gl_PrimitiveIDIn] = tempResult;\n"
                " gl_Position = gl_in[0].gl_Position;\n"
                " EmitVertex();\n"
                " EndPrimitive();\n"
                "}\n";
            subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
        }

        // Fragment stage: result written to the uint color output at location 0.
        {
            const std::string fragment = "${VERSION_DECL}\n"
                                         "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
                                         "#extension GL_KHR_shader_subgroup_ballot: enable\n"
                                         "precision highp int;\n"
                                         "precision highp float;\n"
                                         "layout(location = 0) out uint result;\n"
                                         "layout(binding = 4, std430) readonly buffer Buffer4\n"
                                         "{\n"
                                         " " +
                                         subgroups::getFormatNameForGLSL(caseDef.format) +
                                         " data[];\n"
                                         "};\n"
                                         "void main (void)\n"
                                         "{\n"
                                         " uvec4 mask = subgroupBallot(true);\n" +
                                         bdy +
                                         " result = tempResult;\n"
                                         "}\n";
            programCollection.add("fragment") << glu::FragmentSource(fragment);
        }
        // NOTE(review): presumably registers the pass-through shaders used for stages
        // that are not under test; confirm in glcSubgroupsTestsUtils.
        subgroups::addNoSubgroupShader(programCollection);
    }
}
825
supportedCheck(Context & context,CaseDefinition caseDef)826 void supportedCheck(Context &context, CaseDefinition caseDef)
827 {
828 if (!subgroups::isSubgroupSupported(context))
829 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
830
831 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_ARITHMETIC_BIT))
832 {
833 TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
834 }
835
836 if (subgroups::isDoubleFormat(caseDef.format) && !subgroups::isDoubleSupportedForDevice(context))
837 {
838 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
839 }
840 }
841
noSSBOtest(Context & context,const CaseDefinition caseDef)842 tcu::TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
843 {
844 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
845 {
846 if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
847 {
848 return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
849 " is required to support subgroup operations!");
850 }
851 else
852 {
853 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
854 }
855 }
856
857 subgroups::SSBOData inputData;
858 inputData.format = caseDef.format;
859 inputData.layout = subgroups::SSBOData::LayoutStd140;
860 inputData.numElements = subgroups::maxSupportedSubgroupSize();
861 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
862 inputData.binding = 0u;
863
864 if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
865 return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
866 else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
867 return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1,
868 checkVertexPipelineStages);
869 else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
870 return subgroups::makeTessellationEvaluationFrameBufferTest(
871 context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
872 else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
873 return subgroups::makeTessellationEvaluationFrameBufferTest(
874 context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
875 else
876 TCU_THROW(InternalError, "Unhandled shader stage");
877 }
878
checkShaderStages(Context & context,const CaseDefinition & caseDef)879 bool checkShaderStages(Context &context, const CaseDefinition &caseDef)
880 {
881 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
882 {
883 if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
884 {
885 return false;
886 }
887 else
888 {
889 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
890 }
891 }
892 return true;
893 }
894
test(Context & context,const CaseDefinition caseDef)895 tcu::TestStatus test(Context &context, const CaseDefinition caseDef)
896 {
897 if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
898 {
899 if (!checkShaderStages(context, caseDef))
900 {
901 return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
902 " is required to support subgroup operations!");
903 }
904 subgroups::SSBOData inputData;
905 inputData.format = caseDef.format;
906 inputData.layout = subgroups::SSBOData::LayoutStd430;
907 inputData.numElements = subgroups::maxSupportedSubgroupSize();
908 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
909 inputData.binding = 1u;
910
911 return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
912 }
913 else
914 {
915 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
916
917 ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
918
919 if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
920 {
921 if ((stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
922 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
923 else
924 stages = SHADER_STAGE_FRAGMENT_BIT;
925 }
926
927 if ((ShaderStageFlags)0u == stages)
928 TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
929
930 subgroups::SSBOData inputData;
931 inputData.format = caseDef.format;
932 inputData.layout = subgroups::SSBOData::LayoutStd430;
933 inputData.numElements = subgroups::maxSupportedSubgroupSize();
934 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
935 inputData.binding = 4u;
936 inputData.stages = stages;
937
938 return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
939 }
940 }
941 } // namespace
942
createSubgroupsArithmeticTests(deqp::Context & testCtx)943 deqp::TestCaseGroup *createSubgroupsArithmeticTests(deqp::Context &testCtx)
944 {
945 de::MovePtr<deqp::TestCaseGroup> graphicGroup(
946 new deqp::TestCaseGroup(testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
947 de::MovePtr<deqp::TestCaseGroup> computeGroup(
948 new deqp::TestCaseGroup(testCtx, "compute", "Subgroup arithmetic category tests: compute"));
949 de::MovePtr<deqp::TestCaseGroup> framebufferGroup(
950 new deqp::TestCaseGroup(testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
951
952 const ShaderStageFlags stages[] = {
953 SHADER_STAGE_VERTEX_BIT,
954 SHADER_STAGE_TESS_EVALUATION_BIT,
955 SHADER_STAGE_TESS_CONTROL_BIT,
956 SHADER_STAGE_GEOMETRY_BIT,
957 };
958
959 const Format formats[] = {
960 FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT, FORMAT_R32G32B32A32_SINT,
961 FORMAT_R32_UINT, FORMAT_R32G32_UINT, FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
962 FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT, FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
963 FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT, FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
964 FORMAT_R32_BOOL, FORMAT_R32G32_BOOL, FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
965 };
966
967 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
968 {
969 const Format format = formats[formatIndex];
970
971 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
972 {
973 bool isBool = false;
974 bool isFloat = false;
975
976 switch (format)
977 {
978 default:
979 break;
980 case FORMAT_R32_SFLOAT:
981 case FORMAT_R32G32_SFLOAT:
982 case FORMAT_R32G32B32_SFLOAT:
983 case FORMAT_R32G32B32A32_SFLOAT:
984 case FORMAT_R64_SFLOAT:
985 case FORMAT_R64G64_SFLOAT:
986 case FORMAT_R64G64B64_SFLOAT:
987 case FORMAT_R64G64B64A64_SFLOAT:
988 isFloat = true;
989 break;
990 case FORMAT_R32_BOOL:
991 case FORMAT_R32G32_BOOL:
992 case FORMAT_R32G32B32_BOOL:
993 case FORMAT_R32G32B32A32_BOOL:
994 isBool = true;
995 break;
996 }
997
998 bool isBitwiseOp = false;
999
1000 switch (opTypeIndex)
1001 {
1002 default:
1003 break;
1004 case OPTYPE_AND:
1005 case OPTYPE_INCLUSIVE_AND:
1006 case OPTYPE_EXCLUSIVE_AND:
1007 case OPTYPE_OR:
1008 case OPTYPE_INCLUSIVE_OR:
1009 case OPTYPE_EXCLUSIVE_OR:
1010 case OPTYPE_XOR:
1011 case OPTYPE_INCLUSIVE_XOR:
1012 case OPTYPE_EXCLUSIVE_XOR:
1013 isBitwiseOp = true;
1014 break;
1015 }
1016
1017 if (isFloat && isBitwiseOp)
1018 {
1019 // Skip float with bitwise category.
1020 continue;
1021 }
1022
1023 if (isBool && !isBitwiseOp)
1024 {
1025 // Skip bool when its not the bitwise category.
1026 continue;
1027 }
1028 std::string op = getOpTypeName(opTypeIndex);
1029
1030 {
1031 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format};
1032 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
1033 computeGroup.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format), "",
1034 supportedCheck, initPrograms, test, caseDef);
1035 }
1036
1037 {
1038 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format};
1039 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
1040 graphicGroup.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format), "",
1041 supportedCheck, initPrograms, test, caseDef);
1042 }
1043
1044 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
1045 {
1046 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
1047 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
1048 framebufferGroup.get(),
1049 de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) + "_" +
1050 getShaderStageName(caseDef.shaderStage),
1051 "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
1052 }
1053 }
1054 }
1055
1056 de::MovePtr<deqp::TestCaseGroup> group(
1057 new deqp::TestCaseGroup(testCtx, "arithmetic", "Subgroup arithmetic category tests"));
1058
1059 group->addChild(graphicGroup.release());
1060 group->addChild(computeGroup.release());
1061 group->addChild(framebufferGroup.release());
1062
1063 return group.release();
1064 }
1065
1066 } // namespace subgroups
1067 } // namespace glc
1068