1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  * Copyright (c) 2021 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Mesh Shader Query Tests for VK_EXT_mesh_shader
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktMeshShaderQueryTestsEXT.hpp"
26 #include "vktMeshShaderUtil.hpp"
27 #include "vktTestCase.hpp"
28 #include "vktTestCaseUtil.hpp"
29 
30 #include "vkImageWithMemory.hpp"
31 #include "vkBufferWithMemory.hpp"
32 #include "vkImageUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkObjUtil.hpp"
35 #include "vkCmdUtil.hpp"
36 #include "vkBarrierUtil.hpp"
37 
38 #include "tcuImageCompare.hpp"
39 #include "tcuTextureUtil.hpp"
40 
41 #include "deRandom.hpp"
42 #include "deUniquePtr.hpp"
43 
44 #include <vector>
45 #include <algorithm>
46 #include <sstream>
47 #include <string>
48 #include <numeric>
49 #include <array>
50 #include <limits>
51 
52 namespace vkt
53 {
54 namespace MeshShader
55 {
56 
57 namespace
58 {
59 
60 using namespace vk;
61 
62 using BufferWithMemoryPtr = de::MovePtr<BufferWithMemory>;
63 
// Number of columns the mesh shader uses as "colCount"; each mesh work group outputs this many primitives.
// NOTE(review): presumably also the framebuffer width in pixels -- confirm against the image creation code.
constexpr uint32_t kImageWidth            = 32u;
// Total mesh work groups each draw call is expected to launch (see TestParams::getDrawGroupCount()).
constexpr uint32_t kMeshWorkGroupsPerCall = 4u;
// Task work groups dispatched per draw call when a task shader is used.
constexpr uint32_t kTaskWorkGroupsPerCall = 2u;
// Mesh work groups emitted by each task work group so the per-call total stays at kMeshWorkGroupsPerCall.
constexpr uint32_t kMeshWorkGroupsPerTask = kMeshWorkGroupsPerCall / kTaskWorkGroupsPerCall;

// Local work group size declared in the generated mesh shader.
constexpr uint32_t kMeshLocalInvocationsX = 10u;
constexpr uint32_t kMeshLocalInvocationsY = 4u;
constexpr uint32_t kMeshLocalInvocationsZ = 1u;
constexpr uint32_t kMeshLocalInvocations  = kMeshLocalInvocationsX * kMeshLocalInvocationsY * kMeshLocalInvocationsZ;

// Local work group size declared in the generated task shader; the total also sizes the TaskData::branch array.
constexpr uint32_t kTaskLocalInvocationsX = 1u;
constexpr uint32_t kTaskLocalInvocationsY = 4u;
constexpr uint32_t kTaskLocalInvocationsZ = 6u;
constexpr uint32_t kTaskLocalInvocations  = kTaskLocalInvocationsX * kTaskLocalInvocationsY * kTaskLocalInvocationsZ;

// Sizes in bytes of a 64-bit and a 32-bit query result item.
constexpr VkDeviceSize k64sz = static_cast<VkDeviceSize>(sizeof(uint64_t));
constexpr VkDeviceSize k32sz = static_cast<VkDeviceSize>(sizeof(uint32_t));
81 
// Kinds of query values a test variant may request (stored in TestParams::queryTypes).
enum class QueryType
{
    PRIMITIVES = 0,   // Checked by TestParams::hasPrimitivesQuery(); sized separately from the statistics queries.
    TASK_INVOCATIONS, // Task shader invocation statistic; checked by TestParams::hasTaskInvStat().
    MESH_INVOCATIONS, // Mesh shader invocation statistic; checked by TestParams::hasMeshInvStat().
};
88 
// Selects which mesh draw command MeshQueryInstance::recordDraws() records.
enum class DrawCallType
{
    DIRECT = 0,          // vkCmdDrawMeshTasksEXT, one call per draw.
    INDIRECT,            // vkCmdDrawMeshTasksIndirectEXT, one call per draw block.
    INDIRECT_WITH_COUNT, // vkCmdDrawMeshTasksIndirectCountEXT, one call per draw block with the count read from a buffer.
};
95 
// Output primitive type declared by the generated mesh shader (see toString() and vertsPerPrimitive()).
enum class GeometryType
{
    POINTS = 0, // 1 vertex per primitive.
    LINES,      // 2 vertices per primitive.
    TRIANGLES,  // 3 vertices per primitive.
};
102 
toString(GeometryType geometryType)103 std::string toString(GeometryType geometryType)
104 {
105     std::string result;
106     switch (geometryType)
107     {
108     case GeometryType::POINTS:
109         result = "points";
110         break;
111     case GeometryType::LINES:
112         result = "lines";
113         break;
114     case GeometryType::TRIANGLES:
115         result = "triangles";
116         break;
117     default:
118         DE_ASSERT(false);
119         break;
120     }
121     return result;
122 }
123 
vertsPerPrimitive(GeometryType geometryType)124 uint32_t vertsPerPrimitive(GeometryType geometryType)
125 {
126     uint32_t vertices = 0u;
127     switch (geometryType)
128     {
129     case GeometryType::POINTS:
130         vertices = 1u;
131         break;
132     case GeometryType::LINES:
133         vertices = 2u;
134         break;
135     case GeometryType::TRIANGLES:
136         vertices = 3u;
137         break;
138     default:
139         DE_ASSERT(false);
140         break;
141     }
142     return vertices;
143 }
144 
// Whether and when the queries are reset relative to the point where their results are accessed.
enum class ResetCase
{
    NONE = 0,       // No reset.
    NONE_WITH_HOST, // After checking results normally, reset query from the host and verify availability.
    BEFORE_ACCESS,  // Reset before reading results; the availability value is then expected to be zero.
    AFTER_ACCESS,   // NOTE(review): presumably reset after the results have been read -- confirm in iterate().
};
152 
// How query results are retrieved.
// NOTE(review): presumably COPY maps to vkCmdCopyQueryPoolResults and GET to vkGetQueryPoolResults -- confirm in iterate().
enum class AccessMethod
{
    COPY = 0,
    GET,
};
158 
checkGetQueryRes(VkResult result,bool allowNotReady)159 void checkGetQueryRes(VkResult result, bool allowNotReady)
160 {
161     if (result == VK_SUCCESS || (result == VK_NOT_READY && allowNotReady))
162         return;
163 
164     const auto msg = getResultStr(result);
165     TCU_FAIL(msg.toString());
166 }
167 
// Hands out a new seed on every call. The pseudorandom number generator is used both in the test case and in the
// test instance, so consecutive seeds are spaced by two: the case uses the returned value and the instance uses the
// next one.
uint32_t getNewSeed(void)
{
    static uint32_t nextSeed = 1656078156u;

    const uint32_t current = nextSeed;
    nextSeed               = current + 2u;
    return current;
}
176 
177 struct TestParams
178 {
179     uint32_t randomSeed;
180     std::vector<QueryType> queryTypes;
181     std::vector<uint32_t> drawBlocks;
182     DrawCallType drawCall;
183     GeometryType geometry;
184     ResetCase resetType;
185     AccessMethod access;
186     bool use64Bits;
187     bool availabilityBit;
188     bool waitBit;
189     bool useTaskShader;
190     bool insideRenderPass;
191     bool useSecondary;
192     bool multiView;
193 
swapvkt::MeshShader::__anon3fc888fa0111::TestParams194     void swap(TestParams &other)
195     {
196         std::swap(randomSeed, other.randomSeed);
197         queryTypes.swap(other.queryTypes);
198         drawBlocks.swap(other.drawBlocks);
199         std::swap(drawCall, other.drawCall);
200         std::swap(geometry, other.geometry);
201         std::swap(resetType, other.resetType);
202         std::swap(access, other.access);
203         std::swap(use64Bits, other.use64Bits);
204         std::swap(availabilityBit, other.availabilityBit);
205         std::swap(waitBit, other.waitBit);
206         std::swap(useTaskShader, other.useTaskShader);
207         std::swap(insideRenderPass, other.insideRenderPass);
208         std::swap(useSecondary, other.useSecondary);
209         std::swap(multiView, other.multiView);
210     }
211 
TestParamsvkt::MeshShader::__anon3fc888fa0111::TestParams212     TestParams()
213         : randomSeed(getNewSeed())
214         , queryTypes()
215         , drawBlocks()
216         , drawCall(DrawCallType::DIRECT)
217         , geometry(GeometryType::POINTS)
218         , resetType(ResetCase::NONE)
219         , access(AccessMethod::COPY)
220         , use64Bits(false)
221         , availabilityBit(false)
222         , waitBit(false)
223         , useTaskShader(false)
224         , insideRenderPass(false)
225         , useSecondary(false)
226         , multiView(false)
227     {
228     }
229 
TestParamsvkt::MeshShader::__anon3fc888fa0111::TestParams230     TestParams(const TestParams &other)
231         : randomSeed(other.randomSeed)
232         , queryTypes(other.queryTypes)
233         , drawBlocks(other.drawBlocks)
234         , drawCall(other.drawCall)
235         , geometry(other.geometry)
236         , resetType(other.resetType)
237         , access(other.access)
238         , use64Bits(other.use64Bits)
239         , availabilityBit(other.availabilityBit)
240         , waitBit(other.waitBit)
241         , useTaskShader(other.useTaskShader)
242         , insideRenderPass(other.insideRenderPass)
243         , useSecondary(other.useSecondary)
244         , multiView(other.multiView)
245     {
246     }
247 
TestParamsvkt::MeshShader::__anon3fc888fa0111::TestParams248     TestParams(TestParams &&other) : TestParams()
249     {
250         this->swap(other);
251     }
252 
getTotalDrawCountvkt::MeshShader::__anon3fc888fa0111::TestParams253     uint32_t getTotalDrawCount(void) const
254     {
255         const uint32_t callCount = std::accumulate(drawBlocks.begin(), drawBlocks.end(), 0u);
256         return callCount;
257     }
258 
getImageHeightvkt::MeshShader::__anon3fc888fa0111::TestParams259     uint32_t getImageHeight(void) const
260     {
261         return getTotalDrawCount() * kMeshWorkGroupsPerCall;
262     }
263 
264     // The goal is dispatching 4 mesh work groups per draw call in total. When not using task shaders, we dispatch that number
265     // directly. When using task shaders, we dispatch 2 task work groups that will dispatch 2 mesh work groups each. The axis will
266     // be pseudorandomly chosen in each case.
getDrawGroupCountvkt::MeshShader::__anon3fc888fa0111::TestParams267     uint32_t getDrawGroupCount(void) const
268     {
269         return (useTaskShader ? kTaskWorkGroupsPerCall : kMeshWorkGroupsPerCall);
270     }
271 
272     // Gets the right query result flags for the current parameters.
getQueryResultFlagsvkt::MeshShader::__anon3fc888fa0111::TestParams273     VkQueryResultFlags getQueryResultFlags(void) const
274     {
275         const VkQueryResultFlags queryResultFlags =
276             ((use64Bits ? VK_QUERY_RESULT_64_BIT : 0) | (availabilityBit ? VK_QUERY_RESULT_WITH_AVAILABILITY_BIT : 0) |
277              (waitBit ? VK_QUERY_RESULT_WAIT_BIT : VK_QUERY_RESULT_PARTIAL_BIT));
278         return queryResultFlags;
279     }
280 
281     // Queries will be inherited if they are started outside of a render pass and using secondary command buffers.
282     // - If secondary command buffers are not used, nothing will be inherited.
283     // - If secondary command buffers are used but queries start inside of a render pass, queries will run entirely inside the secondary command buffer.
areQueriesInheritedvkt::MeshShader::__anon3fc888fa0111::TestParams284     bool areQueriesInherited(void) const
285     {
286         return (useSecondary && !insideRenderPass);
287     }
288 
289 protected:
hasQueryTypevkt::MeshShader::__anon3fc888fa0111::TestParams290     bool hasQueryType(QueryType queryType) const
291     {
292         return de::contains(queryTypes.begin(), queryTypes.end(), queryType);
293     }
294 
295 public:
hasPrimitivesQueryvkt::MeshShader::__anon3fc888fa0111::TestParams296     bool hasPrimitivesQuery(void) const
297     {
298         return hasQueryType(QueryType::PRIMITIVES);
299     }
300 
hasMeshInvStatvkt::MeshShader::__anon3fc888fa0111::TestParams301     bool hasMeshInvStat(void) const
302     {
303         return hasQueryType(QueryType::MESH_INVOCATIONS);
304     }
305 
hasTaskInvStatvkt::MeshShader::__anon3fc888fa0111::TestParams306     bool hasTaskInvStat(void) const
307     {
308         return hasQueryType(QueryType::TASK_INVOCATIONS);
309     }
310 
311     struct QuerySizesAndOffsets
312     {
313         VkDeviceSize queryItemSize;
314         VkDeviceSize primitivesQuerySize;
315         VkDeviceSize statsQuerySize;
316         VkDeviceSize statsQueryOffset;
317     };
318 
getViewCountvkt::MeshShader::__anon3fc888fa0111::TestParams319     uint32_t getViewCount(void) const
320     {
321         return (multiView ? 2u : 1u);
322     }
323 
getQuerySizesAndOffsetsvkt::MeshShader::__anon3fc888fa0111::TestParams324     QuerySizesAndOffsets getQuerySizesAndOffsets(void) const
325     {
326         QuerySizesAndOffsets sizesAndOffsets;
327         const VkDeviceSize extraQueryItems = (availabilityBit ? 1ull : 0ull);
328         const VkDeviceSize viewMultiplier  = getViewCount();
329 
330         sizesAndOffsets.queryItemSize       = (use64Bits ? k64sz : k32sz);
331         sizesAndOffsets.primitivesQuerySize = (extraQueryItems + 1ull) * sizesAndOffsets.queryItemSize;
332         sizesAndOffsets.statsQuerySize =
333             (extraQueryItems + (hasTaskInvStat() ? 1ull : 0ull) + (hasMeshInvStat() ? 1ull : 0ull)) *
334             sizesAndOffsets.queryItemSize;
335         sizesAndOffsets.statsQueryOffset =
336             (hasPrimitivesQuery() ? (sizesAndOffsets.primitivesQuerySize * viewMultiplier) : 0ull);
337 
338         return sizesAndOffsets;
339     }
340 };
341 
342 class MeshQueryCase : public vkt::TestCase
343 {
344 public:
MeshQueryCase(tcu::TestContext & testCtx,const std::string & name,TestParams && params)345     MeshQueryCase(tcu::TestContext &testCtx, const std::string &name, TestParams &&params)
346         : vkt::TestCase(testCtx, name)
347         , m_params(std::move(params))
348     {
349     }
~MeshQueryCase(void)350     virtual ~MeshQueryCase(void)
351     {
352     }
353 
354     void initPrograms(vk::SourceCollections &programCollection) const override;
355     TestInstance *createInstance(Context &context) const override;
356     void checkSupport(Context &context) const override;
357 
358 protected:
359     TestParams m_params;
360 };
361 
362 class MeshQueryInstance : public vkt::TestInstance
363 {
364 public:
MeshQueryInstance(Context & context,const TestParams & params)365     MeshQueryInstance(Context &context, const TestParams &params)
366         : vkt::TestInstance(context)
367         , m_params(&params)
368         , m_rnd(params.randomSeed + 1u) // Add 1 to make the instance seed different.
369         , m_indirectBuffer()
370         , m_indirectCountBuffer()
371         , m_fence(createFence(context.getDeviceInterface(), context.getDevice()))
372     {
373     }
~MeshQueryInstance(void)374     virtual ~MeshQueryInstance(void)
375     {
376     }
377 
378     Move<VkRenderPass> makeCustomRenderPass(const DeviceInterface &vkd, VkDevice device, uint32_t layerCount,
379                                             VkFormat format);
380     tcu::TestStatus iterate(void) override;
381 
382 protected:
383     VkDrawMeshTasksIndirectCommandEXT getRandomShuffle(uint32_t groupCount);
384     void recordDraws(const VkCommandBuffer cmdBuffer, const VkPipeline pipeline, const VkPipelineLayout layout);
385     void beginFirstQueries(const VkCommandBuffer cmdBuffer, const std::vector<VkQueryPool> &queryPools) const;
386     void endFirstQueries(const VkCommandBuffer cmdBuffer, const std::vector<VkQueryPool> &queryPools) const;
387     void resetFirstQueries(const VkCommandBuffer cmdBuffer, const std::vector<VkQueryPool> &queryPools,
388                            const uint32_t queryCount) const;
389     void submitCommands(const VkCommandBuffer cmdBuffer) const;
390     void waitForFence() const;
391 
392     const TestParams *m_params;
393     de::Random m_rnd;
394     BufferWithMemoryPtr m_indirectBuffer;
395     BufferWithMemoryPtr m_indirectCountBuffer;
396     Move<VkFence> m_fence;
397 };
398 
initPrograms(vk::SourceCollections & programCollection) const399 void MeshQueryCase::initPrograms(vk::SourceCollections &programCollection) const
400 {
401     const auto meshBuildOpts = getMinMeshEXTBuildOptions(programCollection.usedVulkanVersion);
402     const auto imageHeight   = m_params.getImageHeight();
403 
404     const std::string taskDataDecl = "struct TaskData {\n"
405                                      "    uint branch[" +
406                                      std::to_string(kTaskLocalInvocations) +
407                                      "];\n"
408                                      "    uint drawIndex;\n"
409                                      "};\n"
410                                      "taskPayloadSharedEXT TaskData td;\n";
411 
412     std::ostringstream frag;
413     frag << "#version 460\n"
414          << (m_params.multiView ? "#extension GL_EXT_multiview : enable\n" : "")
415          << "layout (location=0) out vec4 outColor;\n"
416          << "void main (void) { outColor = vec4(0.0, " << (m_params.multiView ? "float(gl_ViewIndex)" : "0.0")
417          << ", 1.0, 1.0); }\n";
418     programCollection.glslSources.add("frag") << glu::FragmentSource(frag.str());
419 
420     std::ostringstream mesh;
421     mesh << "#version 460\n"
422          << "#extension GL_EXT_mesh_shader : enable\n"
423          << "\n"
424          << "layout (local_size_x=" << kMeshLocalInvocationsX << ", local_size_y=" << kMeshLocalInvocationsY
425          << ", local_size_z=" << kMeshLocalInvocationsZ << ") in;\n"
426          << "layout (" << toString(m_params.geometry) << ") out;\n"
427          << "layout (max_vertices=256, max_primitives=256) out;\n"
428          << "\n"
429          << "layout (push_constant, std430) uniform PushConstants {\n"
430          << "    uint prevDrawCalls;\n"
431          << "} pc;\n"
432          << "\n";
433 
434     if (m_params.useTaskShader)
435         mesh << taskDataDecl << "\n";
436 
437     mesh << "\n"
438          << "shared uint currentCol;\n"
439          << "\n"
440          << "void main (void)\n"
441          << "{\n"
442          << "    atomicExchange(currentCol, 0u);\n"
443          << "    barrier();\n"
444          << "\n"
445          << "    const uint colCount = uint(" << kImageWidth << ");\n"
446          << "    const uint rowCount = uint(" << imageHeight << ");\n"
447          << "    const uint rowsPerDraw = uint(" << kMeshWorkGroupsPerCall << ");\n"
448          << "\n"
449          << "    const float pixWidth = 2.0 / float(colCount);\n"
450          << "    const float pixHeight = 2.0 / float(rowCount);\n"
451          << "    const float horDelta = pixWidth / 4.0;\n"
452          << "    const float verDelta = (pixHeight * 3.0) / 8.0;\n"
453          << "\n"
454          << "    const uint DrawIndex = " << (m_params.useTaskShader ? "td.drawIndex" : "uint(gl_DrawID)") << ";\n"
455          << "    const uint currentWGIndex = ("
456          << (m_params.useTaskShader ?
457                  "2u * td.branch[min(gl_LocalInvocationIndex, " + std::to_string(kTaskLocalInvocations - 1u) + ")] + " :
458                  "")
459          << "gl_WorkGroupID.x + gl_WorkGroupID.y + gl_WorkGroupID.z);\n"
460          << "    const uint row = (pc.prevDrawCalls + DrawIndex) * rowsPerDraw + currentWGIndex;\n"
461          << "    const uint vertsPerPrimitive = " << vertsPerPrimitive(m_params.geometry) << ";\n"
462          << "\n"
463          << "    SetMeshOutputsEXT(colCount * vertsPerPrimitive, colCount);\n"
464          << "\n"
465          << "    const uint col = atomicAdd(currentCol, 1);\n"
466          << "    if (col < colCount)\n"
467          << "    {\n"
468          << "        const float xCenter = (float(col) + 0.5) / colCount * 2.0 - 1.0;\n"
469          << "        const float yCenter = (float(row) + 0.5) / rowCount * 2.0 - 1.0;\n"
470          << "\n"
471          << "        const uint firstVert = col * vertsPerPrimitive;\n"
472          << "\n";
473 
474     switch (m_params.geometry)
475     {
476     case GeometryType::POINTS:
477         mesh << "        gl_MeshVerticesEXT[firstVert].gl_Position = vec4(xCenter, yCenter, 0.0, 1.0);\n"
478              << "        gl_MeshVerticesEXT[firstVert].gl_PointSize = 1.0;\n"
479              << "        gl_PrimitivePointIndicesEXT[col] = firstVert;\n";
480         break;
481     case GeometryType::LINES:
482         mesh << "        gl_MeshVerticesEXT[firstVert + 0].gl_Position = vec4(xCenter - horDelta, yCenter, 0.0, 1.0);\n"
483              << "        gl_MeshVerticesEXT[firstVert + 1].gl_Position = vec4(xCenter + horDelta, yCenter - verDelta, "
484                 "0.0, 1.0);\n"
485              << "        gl_PrimitiveLineIndicesEXT[col] = uvec2(firstVert, firstVert + 1);\n";
486         break;
487     case GeometryType::TRIANGLES:
488         mesh << "        gl_MeshVerticesEXT[firstVert + 0].gl_Position = vec4(xCenter           , yCenter - verDelta, "
489                 "0.0, 1.0);\n"
490              << "        gl_MeshVerticesEXT[firstVert + 1].gl_Position = vec4(xCenter - horDelta, yCenter + verDelta, "
491                 "0.0, 1.0);\n"
492              << "        gl_MeshVerticesEXT[firstVert + 2].gl_Position = vec4(xCenter + horDelta, yCenter + verDelta, "
493                 "0.0, 1.0);\n"
494              << "        gl_PrimitiveTriangleIndicesEXT[col] = uvec3(firstVert, firstVert + 1, firstVert + 2);\n";
495         break;
496     default:
497         DE_ASSERT(false);
498         break;
499     }
500 
501     mesh << "    }\n"
502          << "}\n";
503     programCollection.glslSources.add("mesh") << glu::MeshSource(mesh.str()) << meshBuildOpts;
504 
505     if (m_params.useTaskShader)
506     {
507         // See TestParams::getDrawGroupCount().
508         de::Random rnd(m_params.randomSeed);
509         std::vector<uint32_t> meshTaskCount{kMeshWorkGroupsPerTask, 1u, 1u};
510 
511         rnd.shuffle(meshTaskCount.begin(), meshTaskCount.end());
512 
513         std::ostringstream task;
514         task << "#version 460\n"
515              << "#extension GL_EXT_mesh_shader : enable\n"
516              << "\n"
517              << "layout (local_size_x=" << kTaskLocalInvocationsX << ", local_size_y=" << kTaskLocalInvocationsY
518              << ", local_size_z=" << kTaskLocalInvocationsZ << ") in;\n"
519              << "\n"
520              << taskDataDecl << "\n"
521              << "void main ()\n"
522              << "{\n"
523              << "   td.branch[gl_LocalInvocationIndex] = gl_WorkGroupID.x + gl_WorkGroupID.y + gl_WorkGroupID.z;\n"
524              << "   td.drawIndex = uint(gl_DrawID);\n"
525              << "   EmitMeshTasksEXT(" << meshTaskCount.at(0) << ", " << meshTaskCount.at(1) << ", "
526              << meshTaskCount.at(2) << ");\n"
527              << "}\n";
528         programCollection.glslSources.add("task") << glu::TaskSource(task.str()) << meshBuildOpts;
529     }
530 }
531 
createInstance(Context & context) const532 TestInstance *MeshQueryCase::createInstance(Context &context) const
533 {
534     return new MeshQueryInstance(context, m_params);
535 }
536 
checkSupport(Context & context) const537 void MeshQueryCase::checkSupport(Context &context) const
538 {
539     checkTaskMeshShaderSupportEXT(context, m_params.useTaskShader /*requireTask*/, true /*requireMesh*/);
540 
541     const auto &meshFeatures = context.getMeshShaderFeaturesEXT();
542 
543     if (!m_params.queryTypes.empty())
544     {
545         if (!meshFeatures.meshShaderQueries)
546             TCU_THROW(NotSupportedError, "meshShaderQueries not supported");
547     }
548 
549     if (m_params.areQueriesInherited())
550         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_INHERITED_QUERIES);
551 
552     if (m_params.resetType == ResetCase::NONE_WITH_HOST)
553         context.requireDeviceFunctionality("VK_EXT_host_query_reset");
554 
555     if (m_params.multiView)
556     {
557         if (!meshFeatures.multiviewMeshShader)
558             TCU_THROW(NotSupportedError, "multiviewMeshShader not supported");
559 
560         const auto &meshProperties = context.getMeshShaderPropertiesEXT();
561         if (meshProperties.maxMeshMultiviewViewCount < m_params.getViewCount())
562             TCU_THROW(NotSupportedError, "maxMeshMultiviewViewCount too low");
563     }
564 }
565 
getRandomShuffle(uint32_t groupCount)566 VkDrawMeshTasksIndirectCommandEXT MeshQueryInstance::getRandomShuffle(uint32_t groupCount)
567 {
568     std::array<uint32_t, 3> counts{groupCount, 1u, 1u};
569     m_rnd.shuffle(counts.begin(), counts.end());
570 
571     const VkDrawMeshTasksIndirectCommandEXT result{counts[0], counts[1], counts[2]};
572     return result;
573 }
574 
recordDraws(const VkCommandBuffer cmdBuffer,const VkPipeline pipeline,const VkPipelineLayout layout)575 void MeshQueryInstance::recordDraws(const VkCommandBuffer cmdBuffer, const VkPipeline pipeline,
576                                     const VkPipelineLayout layout)
577 {
578     const auto &vkd           = m_context.getDeviceInterface();
579     const auto device         = m_context.getDevice();
580     auto &alloc               = m_context.getDefaultAllocator();
581     const auto drawGroupCount = m_params->getDrawGroupCount();
582     const auto pcSize         = static_cast<uint32_t>(sizeof(uint32_t));
583 
584     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline);
585 
586     if (m_params->drawCall == DrawCallType::DIRECT)
587     {
588         uint32_t totalDrawCalls = 0u;
589         for (const auto &blockSize : m_params->drawBlocks)
590         {
591             for (uint32_t drawIdx = 0u; drawIdx < blockSize; ++drawIdx)
592             {
593                 const auto counts = getRandomShuffle(drawGroupCount);
594                 vkd.cmdPushConstants(cmdBuffer, layout, VK_SHADER_STAGE_MESH_BIT_EXT, 0u, pcSize, &totalDrawCalls);
595                 vkd.cmdDrawMeshTasksEXT(cmdBuffer, counts.groupCountX, counts.groupCountY, counts.groupCountZ);
596                 ++totalDrawCalls;
597             }
598         }
599     }
600     else if (m_params->drawCall == DrawCallType::INDIRECT || m_params->drawCall == DrawCallType::INDIRECT_WITH_COUNT)
601     {
602         if (m_params->drawBlocks.empty())
603             return;
604 
605         const auto totalDrawCount = m_params->getTotalDrawCount();
606         const auto cmdSize        = static_cast<uint32_t>(sizeof(VkDrawMeshTasksIndirectCommandEXT));
607 
608         std::vector<VkDrawMeshTasksIndirectCommandEXT> indirectCommands;
609         indirectCommands.reserve(totalDrawCount);
610 
611         for (uint32_t i = 0u; i < totalDrawCount; ++i)
612             indirectCommands.emplace_back(getRandomShuffle(drawGroupCount));
613 
614         // Copy the array to a host-visible buffer.
615         // Note: We make sure all indirect buffers are allocated with a non-zero size by adding cmdSize to the expected size.
616         // Size of buffer must be greater than stride * (maxDrawCount - 1) + offset + sizeof(VkDrawMeshTasksIndirectCommandEXT) so we multiply by 2
617         const auto indirectBufferSize       = de::dataSize(indirectCommands);
618         const auto indirectBufferCreateInfo = makeBufferCreateInfo(
619             static_cast<VkDeviceSize>((indirectBufferSize + cmdSize) * 2), VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT);
620 
621         m_indirectBuffer = BufferWithMemoryPtr(
622             new BufferWithMemory(vkd, device, alloc, indirectBufferCreateInfo, MemoryRequirement::HostVisible));
623         auto &indirectBufferAlloc = m_indirectBuffer->getAllocation();
624         void *indirectBufferData  = indirectBufferAlloc.getHostPtr();
625 
626         deMemcpy(indirectBufferData, indirectCommands.data(), indirectBufferSize);
627         flushAlloc(vkd, device, indirectBufferAlloc);
628 
629         if (m_params->drawCall == DrawCallType::INDIRECT)
630         {
631             uint32_t accumulatedCount = 0u;
632 
633             for (const auto &blockSize : m_params->drawBlocks)
634             {
635                 const auto offset = static_cast<VkDeviceSize>(cmdSize * accumulatedCount);
636                 vkd.cmdPushConstants(cmdBuffer, layout, VK_SHADER_STAGE_MESH_BIT_EXT, 0u, pcSize, &accumulatedCount);
637                 vkd.cmdDrawMeshTasksIndirectEXT(cmdBuffer, m_indirectBuffer->get(), offset, blockSize, cmdSize);
638                 accumulatedCount += blockSize;
639             }
640         }
641         else
642         {
643             // Copy the "block sizes" to a host-visible buffer.
644             const auto indirectCountBufferSize       = de::dataSize(m_params->drawBlocks);
645             const auto indirectCountBufferCreateInfo = makeBufferCreateInfo(
646                 static_cast<VkDeviceSize>(indirectCountBufferSize + cmdSize), VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT);
647 
648             m_indirectCountBuffer          = BufferWithMemoryPtr(new BufferWithMemory(
649                 vkd, device, alloc, indirectCountBufferCreateInfo, MemoryRequirement::HostVisible));
650             auto &indirectCountBufferAlloc = m_indirectCountBuffer->getAllocation();
651             void *indirectCountBufferData  = indirectCountBufferAlloc.getHostPtr();
652 
653             deMemcpy(indirectCountBufferData, m_params->drawBlocks.data(), indirectCountBufferSize);
654             flushAlloc(vkd, device, indirectCountBufferAlloc);
655 
656             // Record indirect draws with count.
657             uint32_t accumulatedCount = 0u;
658 
659             for (uint32_t countIdx = 0u; countIdx < m_params->drawBlocks.size(); ++countIdx)
660             {
661                 const auto &blockSize  = m_params->drawBlocks.at(countIdx);
662                 const auto offset      = static_cast<VkDeviceSize>(cmdSize * accumulatedCount);
663                 const auto countOffset = static_cast<VkDeviceSize>(sizeof(uint32_t) * countIdx);
664 
665                 vkd.cmdPushConstants(cmdBuffer, layout, VK_SHADER_STAGE_MESH_BIT_EXT, 0u, pcSize, &accumulatedCount);
666                 vkd.cmdDrawMeshTasksIndirectCountEXT(cmdBuffer, m_indirectBuffer->get(), offset,
667                                                      m_indirectCountBuffer->get(), countOffset, blockSize * 2u,
668                                                      cmdSize);
669                 accumulatedCount += blockSize;
670             }
671         }
672     }
673     else
674     {
675         DE_ASSERT(false);
676     }
677 }
678 
beginFirstQueries(const VkCommandBuffer cmdBuffer,const std::vector<VkQueryPool> & queryPools) const679 void MeshQueryInstance::beginFirstQueries(const VkCommandBuffer cmdBuffer,
680                                           const std::vector<VkQueryPool> &queryPools) const
681 {
682     const auto &vkd = m_context.getDeviceInterface();
683     for (const auto &pool : queryPools)
684         vkd.cmdBeginQuery(cmdBuffer, pool, 0u, 0u);
685 }
686 
endFirstQueries(const VkCommandBuffer cmdBuffer,const std::vector<VkQueryPool> & queryPools) const687 void MeshQueryInstance::endFirstQueries(const VkCommandBuffer cmdBuffer,
688                                         const std::vector<VkQueryPool> &queryPools) const
689 {
690     const auto &vkd = m_context.getDeviceInterface();
691     for (const auto &pool : queryPools)
692         vkd.cmdEndQuery(cmdBuffer, pool, 0u);
693 }
694 
resetFirstQueries(const VkCommandBuffer cmdBuffer,const std::vector<VkQueryPool> & queryPools,const uint32_t queryCount) const695 void MeshQueryInstance::resetFirstQueries(const VkCommandBuffer cmdBuffer, const std::vector<VkQueryPool> &queryPools,
696                                           const uint32_t queryCount) const
697 {
698     const auto &vkd = m_context.getDeviceInterface();
699     for (const auto &pool : queryPools)
700         vkd.cmdResetQueryPool(cmdBuffer, pool, 0u, queryCount);
701 }
702 
submitCommands(const VkCommandBuffer cmdBuffer) const703 void MeshQueryInstance::submitCommands(const VkCommandBuffer cmdBuffer) const
704 {
705     const auto &vkd  = m_context.getDeviceInterface();
706     const auto queue = m_context.getUniversalQueue();
707 
708     const VkSubmitInfo submitInfo = {
709         VK_STRUCTURE_TYPE_SUBMIT_INFO, // VkStructureType sType;
710         nullptr,                       // const void* pNext;
711         0u,                            // uint32_t waitSemaphoreCount;
712         nullptr,                       // const VkSemaphore* pWaitSemaphores;
713         nullptr,                       // const VkPipelineStageFlags* pWaitDstStageMask;
714         1u,                            // uint32_t commandBufferCount;
715         &cmdBuffer,                    // const VkCommandBuffer* pCommandBuffers;
716         0u,                            // uint32_t signalSemaphoreCount;
717         nullptr,                       // const VkSemaphore* pSignalSemaphores;
718     };
719 
720     VK_CHECK(vkd.queueSubmit(queue, 1u, &submitInfo, m_fence.get()));
721 }
722 
waitForFence(void) const723 void MeshQueryInstance::waitForFence(void) const
724 {
725     const auto &vkd   = m_context.getDeviceInterface();
726     const auto device = m_context.getDevice();
727 
728     VK_CHECK(vkd.waitForFences(device, 1u, &m_fence.get(), VK_TRUE, ~0ull));
729 }
730 
731 // Read query item from memory. Always returns uint64_t for convenience. Advances pointer to the next item.
readFromPtrAndAdvance(uint8_t ** const ptr,VkDeviceSize itemSize)732 uint64_t readFromPtrAndAdvance(uint8_t **const ptr, VkDeviceSize itemSize)
733 {
734     const auto itemSizeSz = static_cast<size_t>(itemSize);
735     uint64_t result       = std::numeric_limits<uint64_t>::max();
736 
737     if (itemSize == k64sz)
738     {
739         deMemcpy(&result, *ptr, itemSizeSz);
740     }
741     else if (itemSize == k32sz)
742     {
743         uint32_t aux = std::numeric_limits<uint32_t>::max();
744         deMemcpy(&aux, *ptr, itemSizeSz);
745         result = static_cast<uint64_t>(aux);
746     }
747     else
748         DE_ASSERT(false);
749 
750     *ptr += itemSizeSz;
751     return result;
752 }
753 
754 // General procedure to verify correctness of the availability bit, which does not depend on the exact query.
readAndVerifyAvailabilityBit(uint8_t ** const resultsPtr,VkDeviceSize itemSize,const TestParams & params,const std::string & queryName)755 void readAndVerifyAvailabilityBit(uint8_t **const resultsPtr, VkDeviceSize itemSize, const TestParams &params,
756                                   const std::string &queryName)
757 {
758     const uint64_t availabilityBitVal = readFromPtrAndAdvance(resultsPtr, itemSize);
759 
760     if (params.resetType == ResetCase::BEFORE_ACCESS)
761     {
762         if (availabilityBitVal)
763         {
764             std::ostringstream msg;
765             msg << queryName << " availability bit expected to be zero due to reset before access, but found "
766                 << availabilityBitVal;
767             TCU_FAIL(msg.str());
768         }
769     }
770     else if (params.waitBit)
771     {
772         if (!availabilityBitVal)
773         {
774             std::ostringstream msg;
775             msg << queryName << " availability expected to be true due to wait bit and not previous reset, but found "
776                 << availabilityBitVal;
777             TCU_FAIL(msg.str());
778         }
779     }
780 }
781 
782 // Verifies a query counter has the right value given the test parameters.
783 // - readVal is the reported counter value.
784 // - expectedMinVal and expectedMaxVal are the known right counts under "normal" circumstances.
785 // - The actual range of valid values will be adjusted depending on the test parameters (wait bit, reset, etc).
verifyQueryCounter(uint64_t readVal,uint64_t expectedMinVal,uint64_t expectedMaxVal,const TestParams & params,const std::string & queryName)786 void verifyQueryCounter(uint64_t readVal, uint64_t expectedMinVal, uint64_t expectedMaxVal, const TestParams &params,
787                         const std::string &queryName)
788 {
789     uint64_t minVal = expectedMinVal;
790     uint64_t maxVal = expectedMaxVal;
791 
792     // Resetting a query via vkCmdResetQueryPool or vkResetQueryPool sets the status to unavailable and makes the numerical results undefined.
793     const bool wasReset = (params.resetType == ResetCase::BEFORE_ACCESS);
794 
795     if (!wasReset)
796     {
797         if (!params.waitBit)
798             minVal = 0ull;
799 
800         if (!de::inRange(readVal, minVal, maxVal))
801         {
802             std::ostringstream msg;
803             msg << queryName << " not in expected range: " << readVal << " out of [" << minVal << ", " << maxVal << "]";
804             TCU_FAIL(msg.str());
805         }
806     }
807 }
808 
makeCustomRenderPass(const DeviceInterface & vkd,VkDevice device,uint32_t layerCount,VkFormat format)809 Move<VkRenderPass> MeshQueryInstance::makeCustomRenderPass(const DeviceInterface &vkd, VkDevice device,
810                                                            uint32_t layerCount, VkFormat format)
811 {
812     DE_ASSERT(layerCount > 0u);
813 
814     const VkAttachmentDescription colorAttachmentDescription = {
815         0u,                                       // VkAttachmentDescriptionFlags    flags
816         format,                                   // VkFormat                        format
817         VK_SAMPLE_COUNT_1_BIT,                    // VkSampleCountFlagBits           samples
818         VK_ATTACHMENT_LOAD_OP_CLEAR,              // VkAttachmentLoadOp              loadOp
819         VK_ATTACHMENT_STORE_OP_STORE,             // VkAttachmentStoreOp             storeOp
820         VK_ATTACHMENT_LOAD_OP_DONT_CARE,          // VkAttachmentLoadOp              stencilLoadOp
821         VK_ATTACHMENT_STORE_OP_DONT_CARE,         // VkAttachmentStoreOp             stencilStoreOp
822         VK_IMAGE_LAYOUT_UNDEFINED,                // VkImageLayout                   initialLayout
823         VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL, // VkImageLayout                   finalLayout
824     };
825 
826     const VkAttachmentReference colorAttachmentRef =
827         makeAttachmentReference(0u, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL);
828 
829     const VkSubpassDescription subpassDescription = {
830         0u,                              // VkSubpassDescriptionFlags       flags
831         VK_PIPELINE_BIND_POINT_GRAPHICS, // VkPipelineBindPoint             pipelineBindPoint
832         0u,                              // uint32_t                        inputAttachmentCount
833         nullptr,                         // const VkAttachmentReference*    pInputAttachments
834         1u,                              // uint32_t                        colorAttachmentCount
835         &colorAttachmentRef,             // const VkAttachmentReference*    pColorAttachments
836         nullptr,                         // const VkAttachmentReference*    pResolveAttachments
837         nullptr,                         // const VkAttachmentReference*    pDepthStencilAttachment
838         0u,                              // uint32_t                        preserveAttachmentCount
839         nullptr                          // const uint32_t*                 pPreserveAttachments
840     };
841 
842     const uint32_t viewMask                                   = m_params->multiView ? ((1u << layerCount) - 1u) : 0u;
843     const uint32_t correlationMaskCount                       = m_params->multiView ? 1u : 0u;
844     const VkRenderPassMultiviewCreateInfo multiviewCreateInfo = {
845         VK_STRUCTURE_TYPE_RENDER_PASS_MULTIVIEW_CREATE_INFO, // VkStructureType sType;
846         nullptr,                                             // const void* pNext;
847         1u,                                                  // uint32_t subpassCount;
848         &viewMask,                                           // const uint32_t* pViewMasks;
849         0u,                                                  // uint32_t dependencyCount;
850         nullptr,                                             // const int32_t* pViewOffsets;
851         correlationMaskCount,                                // uint32_t correlationMaskCount;
852         &viewMask,                                           // const uint32_t* pCorrelationMasks;
853     };
854 
855     const VkRenderPassCreateInfo renderPassInfo = {
856         VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO, // VkStructureType                   sType
857         &multiviewCreateInfo,                      // const void*                       pNext
858         0u,                                        // VkRenderPassCreateFlags           flags
859         1u,                                        // uint32_t                          attachmentCount
860         &colorAttachmentDescription,               // const VkAttachmentDescription*    pAttachments
861         1u,                                        // uint32_t                          subpassCount
862         &subpassDescription,                       // const VkSubpassDescription*       pSubpasses
863         0u,                                        // uint32_t                          dependencyCount
864         nullptr,                                   // const VkSubpassDependency*        pDependencies
865     };
866 
867     return createRenderPass(vkd, device, &renderPassInfo);
868 }
869 
iterate(void)870 tcu::TestStatus MeshQueryInstance::iterate(void)
871 {
872     const auto &vkd       = m_context.getDeviceInterface();
873     const auto device     = m_context.getDevice();
874     auto &alloc           = m_context.getDefaultAllocator();
875     const auto queue      = m_context.getUniversalQueue();
876     const auto queueIndex = m_context.getUniversalQueueFamilyIndex();
877 
878     const auto colorFormat    = VK_FORMAT_R8G8B8A8_UNORM;
879     const auto colorTcuFormat = mapVkFormat(colorFormat);
880     const auto imageHeight    = m_params->getImageHeight();
881     const auto colorExtent    = makeExtent3D(kImageWidth, std::max(imageHeight, 1u), 1u);
882     const auto viewCount      = m_params->getViewCount();
883     const tcu::IVec3 colorTcuExtent(static_cast<int>(colorExtent.width), static_cast<int>(colorExtent.height),
884                                     static_cast<int>(viewCount));
885     const auto colorUsage = (VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT);
886     const tcu::Vec4 clearColor(0.0f, 0.0f, 0.0f, 1.0f);
887     const auto expectedPrims   = (imageHeight * kImageWidth);
888     const auto expectedTaskInv = (m_params->useTaskShader ? (imageHeight * kTaskLocalInvocations / 2u) : 0u);
889     const auto expectedMeshInv = imageHeight * kMeshLocalInvocations;
890     const auto imageViewType   = ((viewCount > 1u) ? VK_IMAGE_VIEW_TYPE_2D_ARRAY : VK_IMAGE_VIEW_TYPE_2D);
891 
892     // Color buffer.
893     const VkImageCreateInfo colorBufferCreateInfo = {
894         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
895         nullptr,                             // const void* pNext;
896         0u,                                  // VkImageCreateFlags flags;
897         VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
898         colorFormat,                         // VkFormat format;
899         colorExtent,                         // VkExtent3D extent;
900         1u,                                  // uint32_t mipLevels;
901         viewCount,                           // uint32_t arrayLayers;
902         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
903         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
904         colorUsage,                          // VkImageUsageFlags usage;
905         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
906         0u,                                  // uint32_t queueFamilyIndexCount;
907         nullptr,                             // const uint32_t* pQueueFamilyIndices;
908         VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
909     };
910 
911     const ImageWithMemory colorBuffer(vkd, device, alloc, colorBufferCreateInfo, MemoryRequirement::Any);
912     const auto colorSRR  = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, viewCount);
913     const auto colorSRL  = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, viewCount);
914     const auto colorView = makeImageView(vkd, device, colorBuffer.get(), imageViewType, colorFormat, colorSRR);
915 
916     // Verification buffer.
917     DE_ASSERT(colorExtent.depth == 1u);
918     const VkDeviceSize verifBufferSize = colorExtent.width * colorExtent.height * viewCount *
919                                          static_cast<VkDeviceSize>(tcu::getPixelSize(colorTcuFormat));
920     const auto verifBufferCreateInfo = makeBufferCreateInfo(verifBufferSize, VK_BUFFER_USAGE_TRANSFER_DST_BIT);
921     const BufferWithMemory verifBuffer(vkd, device, alloc, verifBufferCreateInfo, MemoryRequirement::HostVisible);
922 
923     // Shader modules.
924     const auto &binaries = m_context.getBinaryCollection();
925     const auto taskModule =
926         (binaries.contains("task") ? createShaderModule(vkd, device, binaries.get("task")) : Move<VkShaderModule>());
927     const auto meshModule = createShaderModule(vkd, device, binaries.get("mesh"));
928     const auto fragModule = createShaderModule(vkd, device, binaries.get("frag"));
929 
930     // Pipeline layout.
931     const auto pcSize         = static_cast<uint32_t>(sizeof(uint32_t));
932     const auto pcRange        = makePushConstantRange(VK_SHADER_STAGE_MESH_BIT_EXT, 0u, pcSize);
933     const auto pipelineLayout = makePipelineLayout(vkd, device, DE_NULL, &pcRange);
934 
935     // Render pass, framebuffer, viewports, scissors.
936     const auto renderPass = makeCustomRenderPass(vkd, device, viewCount, colorFormat);
937     const auto framebuffer =
938         makeFramebuffer(vkd, device, renderPass.get(), colorView.get(), colorExtent.width, colorExtent.height);
939 
940     const std::vector<VkViewport> viewports(1u, makeViewport(colorExtent));
941     const std::vector<VkRect2D> scissors(1u, makeRect2D(colorExtent));
942 
943     const auto pipeline = makeGraphicsPipeline(vkd, device, pipelineLayout.get(), taskModule.get(), meshModule.get(),
944                                                fragModule.get(), renderPass.get(), viewports, scissors);
945 
946     // Command pool and buffers.
947     const auto cmdPool        = makeCommandPool(vkd, device, queueIndex);
948     const auto cmdBufferPtr   = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
949     const auto resetCmdBuffer = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
950     const auto cmdBuffer      = cmdBufferPtr.get();
951     const auto rawPipeline    = pipeline.get();
952     const auto rawPipeLayout  = pipelineLayout.get();
953 
954     Move<VkCommandBuffer> secCmdBufferPtr;
955     VkCommandBuffer secCmdBuffer = DE_NULL;
956 
957     if (m_params->useSecondary)
958     {
959         secCmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_SECONDARY);
960         secCmdBuffer    = secCmdBufferPtr.get();
961     }
962 
963     // Create the query pools that we need.
964     Move<VkQueryPool> primitivesQueryPool;
965     Move<VkQueryPool> statsQueryPool;
966 
967     const bool hasPrimitivesQuery = m_params->hasPrimitivesQuery();
968     const bool hasMeshInvStat     = m_params->hasMeshInvStat();
969     const bool hasTaskInvStat     = m_params->hasTaskInvStat();
970     const bool hasStatsQuery      = (hasMeshInvStat || hasTaskInvStat);
971 
972     std::vector<VkQueryPool> allQueryPools;
973 
974     if (hasPrimitivesQuery)
975     {
976         const VkQueryPoolCreateInfo queryPoolCreateInfo = {
977             VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO,    // VkStructureType sType;
978             nullptr,                                     // const void* pNext;
979             0u,                                          // VkQueryPoolCreateFlags flags;
980             VK_QUERY_TYPE_MESH_PRIMITIVES_GENERATED_EXT, // VkQueryType queryType;
981             viewCount,                                   // uint32_t queryCount;
982             0u,                                          // VkQueryPipelineStatisticFlags pipelineStatistics;
983         };
984         primitivesQueryPool = createQueryPool(vkd, device, &queryPoolCreateInfo);
985         allQueryPools.push_back(primitivesQueryPool.get());
986     }
987 
988     const VkQueryPipelineStatisticFlags statQueryFlags =
989         ((hasMeshInvStat ? VK_QUERY_PIPELINE_STATISTIC_MESH_SHADER_INVOCATIONS_BIT_EXT : 0) |
990          (hasTaskInvStat ? VK_QUERY_PIPELINE_STATISTIC_TASK_SHADER_INVOCATIONS_BIT_EXT : 0));
991 
992     if (hasStatsQuery)
993     {
994         const VkQueryPoolCreateInfo queryPoolCreateInfo = {
995             VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, // VkStructureType sType;
996             nullptr,                                  // const void* pNext;
997             0u,                                       // VkQueryPoolCreateFlags flags;
998             VK_QUERY_TYPE_PIPELINE_STATISTICS,        // VkQueryType queryType;
999             viewCount,                                // uint32_t queryCount;
1000             statQueryFlags,                           // VkQueryPipelineStatisticFlags pipelineStatistics;
1001         };
1002         statsQueryPool = createQueryPool(vkd, device, &queryPoolCreateInfo);
1003         allQueryPools.push_back(statsQueryPool.get());
1004     }
1005 
1006     // Some query result parameters.
1007     const auto querySizesAndOffsets = m_params->getQuerySizesAndOffsets();
1008     const size_t maxResultSize      = k64sz * 10ull; // 10 items at most: (prim+avail+task+mesh+avail)*2.
1009     const auto statsQueryOffsetSz   = static_cast<size_t>(querySizesAndOffsets.statsQueryOffset);
1010 
1011     // Create output buffer for the queries.
1012     BufferWithMemoryPtr queryResultsBuffer;
1013     if (m_params->access == AccessMethod::COPY)
1014     {
1015         const auto queryResultsBufferInfo =
1016             makeBufferCreateInfo(static_cast<VkDeviceSize>(maxResultSize), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
1017         queryResultsBuffer = BufferWithMemoryPtr(
1018             new BufferWithMemory(vkd, device, alloc, queryResultsBufferInfo, MemoryRequirement::HostVisible));
1019     }
1020     std::vector<uint8_t> queryResultsHostVec(maxResultSize, 0);
1021 
1022     const auto statsDataHostVecPtr = queryResultsHostVec.data() + statsQueryOffsetSz;
1023     const auto statsRemainingSize  = maxResultSize - statsQueryOffsetSz;
1024 
1025     // Result flags when obtaining query results.
1026     const auto queryResultFlags = m_params->getQueryResultFlags();
1027 
1028     // Reset queries before use.
1029     // Queries will be reset in a separate command buffer to make sure they are always properly reset before use.
1030     // We could do this with VK_EXT_host_query_reset too.
1031     {
1032         beginCommandBuffer(vkd, resetCmdBuffer.get());
1033         resetFirstQueries(resetCmdBuffer.get(), allQueryPools, viewCount);
1034         endCommandBuffer(vkd, resetCmdBuffer.get());
1035         submitCommandsAndWait(vkd, device, queue, resetCmdBuffer.get());
1036     }
1037 
1038     // Command recording.
1039     beginCommandBuffer(vkd, cmdBuffer);
1040 
1041     if (m_params->useSecondary)
1042     {
1043         const VkCommandBufferInheritanceInfo inheritanceInfo = {
1044             VK_STRUCTURE_TYPE_COMMAND_BUFFER_INHERITANCE_INFO, // VkStructureType sType;
1045             nullptr,                                           // const void* pNext;
1046             renderPass.get(),                                  // VkRenderPass renderPass;
1047             0u,                                                // uint32_t subpass;
1048             framebuffer.get(),                                 // VkFramebuffer framebuffer;
1049             VK_FALSE,                                          // VkBool32 occlusionQueryEnable;
1050             0u,                                                // VkQueryControlFlags queryFlags;
1051             (m_params->areQueriesInherited() ? statQueryFlags :
1052                                                0u), // VkQueryPipelineStatisticFlags pipelineStatistics;
1053         };
1054 
1055         const auto secCmdBufferFlags =
1056             (VK_COMMAND_BUFFER_USAGE_RENDER_PASS_CONTINUE_BIT | VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT);
1057 
1058         const VkCommandBufferBeginInfo secBeginInfo = {
1059             VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, // VkStructureType sType;
1060             nullptr,                                     // const void* pNext;
1061             secCmdBufferFlags,                           // VkCommandBufferUsageFlags flags;
1062             &inheritanceInfo,                            // const VkCommandBufferInheritanceInfo* pInheritanceInfo;
1063         };
1064 
1065         VK_CHECK(vkd.beginCommandBuffer(secCmdBuffer, &secBeginInfo));
1066     }
1067 
1068     const auto subpassContents =
1069         (m_params->useSecondary ? VK_SUBPASS_CONTENTS_SECONDARY_COMMAND_BUFFERS : VK_SUBPASS_CONTENTS_INLINE);
1070 
1071     // 4 cases:
1072     //
1073     // * Only primary, inside render pass
1074     // * Only primary, outside render pass
1075     // * Primary and secondary, inside render pass (query in secondary)
1076     // * Primary and secondary, outside render pass (query inheritance)
1077 
1078     if (!m_params->useSecondary)
1079     {
1080         if (m_params->insideRenderPass)
1081         {
1082             beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0), clearColor,
1083                             subpassContents);
1084             beginFirstQueries(cmdBuffer, allQueryPools);
1085             recordDraws(cmdBuffer, rawPipeline, rawPipeLayout);
1086             endFirstQueries(cmdBuffer, allQueryPools);
1087             endRenderPass(vkd, cmdBuffer);
1088         }
1089         else
1090         {
1091             DE_ASSERT(!m_params->multiView);
1092             beginFirstQueries(cmdBuffer, allQueryPools);
1093             beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0), clearColor,
1094                             subpassContents);
1095             recordDraws(cmdBuffer, rawPipeline, rawPipeLayout);
1096             endRenderPass(vkd, cmdBuffer);
1097             endFirstQueries(cmdBuffer, allQueryPools);
1098         }
1099     }
1100     else
1101     {
1102         if (m_params->insideRenderPass) // Queries in secondary command buffer.
1103         {
1104             beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0), clearColor,
1105                             subpassContents);
1106             beginFirstQueries(secCmdBuffer, allQueryPools);
1107             recordDraws(secCmdBuffer, rawPipeline, rawPipeLayout);
1108             endFirstQueries(secCmdBuffer, allQueryPools);
1109             endCommandBuffer(vkd, secCmdBuffer);
1110             vkd.cmdExecuteCommands(cmdBuffer, 1u, &secCmdBuffer);
1111             endRenderPass(vkd, cmdBuffer);
1112         }
1113         else // Inherited queries case.
1114         {
1115             DE_ASSERT(!m_params->multiView);
1116             beginFirstQueries(cmdBuffer, allQueryPools);
1117             beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0), clearColor,
1118                             subpassContents);
1119             recordDraws(secCmdBuffer, rawPipeline, rawPipeLayout);
1120             endCommandBuffer(vkd, secCmdBuffer);
1121             vkd.cmdExecuteCommands(cmdBuffer, 1u, &secCmdBuffer);
1122             endRenderPass(vkd, cmdBuffer);
1123             endFirstQueries(cmdBuffer, allQueryPools);
1124         }
1125     }
1126 
1127     // Render to copy barrier.
1128     {
1129         const auto preCopyImgBarrier = makeImageMemoryBarrier(
1130             VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
1131             VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, colorBuffer.get(), colorSRR);
1132         cmdPipelineImageMemoryBarrier(vkd, cmdBuffer, VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT,
1133                                       VK_PIPELINE_STAGE_TRANSFER_BIT, &preCopyImgBarrier);
1134     }
1135 
1136     if (m_params->resetType == ResetCase::BEFORE_ACCESS)
1137         resetFirstQueries(cmdBuffer, allQueryPools, viewCount);
1138 
1139     if (m_params->access == AccessMethod::COPY)
1140     {
1141         if (hasPrimitivesQuery)
1142             vkd.cmdCopyQueryPoolResults(cmdBuffer, primitivesQueryPool.get(), 0u, viewCount, queryResultsBuffer->get(),
1143                                         0ull, querySizesAndOffsets.primitivesQuerySize, queryResultFlags);
1144 
1145         if (hasStatsQuery)
1146             vkd.cmdCopyQueryPoolResults(cmdBuffer, statsQueryPool.get(), 0u, viewCount, queryResultsBuffer->get(),
1147                                         querySizesAndOffsets.statsQueryOffset, querySizesAndOffsets.statsQuerySize,
1148                                         queryResultFlags);
1149     }
1150 
1151     if (m_params->resetType == ResetCase::AFTER_ACCESS)
1152         resetFirstQueries(cmdBuffer, allQueryPools, viewCount);
1153 
1154     // Copy color attachment to verification buffer.
1155     {
1156         const auto copyRegion = makeBufferImageCopy(colorExtent, colorSRL);
1157         vkd.cmdCopyImageToBuffer(cmdBuffer, colorBuffer.get(), VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, verifBuffer.get(),
1158                                  1u, &copyRegion);
1159     }
1160 
1161     // This barrier applies to both the color verification buffer and the queries if they were copied.
1162     const auto postCopyBarrier = makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
1163     cmdPipelineMemoryBarrier(vkd, cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
1164                              &postCopyBarrier);
1165 
1166     endCommandBuffer(vkd, cmdBuffer);
1167     submitCommands(cmdBuffer);
1168 
1169     // When using GET, obtain results before actually waiting for the fence if possible. This way it's more interesting for cases
1170     // that do not use the wait bit.
1171     if (m_params->access == AccessMethod::GET)
1172     {
1173         // When resetting queries before access, we need to make sure the reset operation has really taken place.
1174         if (m_params->resetType == ResetCase::BEFORE_ACCESS)
1175             waitForFence();
1176 
1177         const bool allowNotReady = !m_params->waitBit;
1178 
1179         if (hasPrimitivesQuery)
1180         {
1181             const auto res = vkd.getQueryPoolResults(device, primitivesQueryPool.get(), 0u, viewCount,
1182                                                      de::dataSize(queryResultsHostVec), queryResultsHostVec.data(),
1183                                                      querySizesAndOffsets.primitivesQuerySize, queryResultFlags);
1184             checkGetQueryRes(res, allowNotReady);
1185         }
1186 
1187         if (hasStatsQuery)
1188         {
1189             const auto res =
1190                 vkd.getQueryPoolResults(device, statsQueryPool.get(), 0u, viewCount, statsRemainingSize,
1191                                         statsDataHostVecPtr, querySizesAndOffsets.statsQuerySize, queryResultFlags);
1192             checkGetQueryRes(res, allowNotReady);
1193         }
1194     }
1195 
1196     waitForFence();
1197 
1198     // Verify color buffer.
1199     {
1200         auto &log              = m_context.getTestContext().getLog();
1201         auto &verifBufferAlloc = verifBuffer.getAllocation();
1202         void *verifBufferData  = verifBufferAlloc.getHostPtr();
1203 
1204         invalidateAlloc(vkd, device, verifBufferAlloc);
1205 
1206         tcu::ConstPixelBufferAccess verifAccess(colorTcuFormat, colorTcuExtent, verifBufferData);
1207         const tcu::Vec4 threshold(0.0f, 0.0f, 0.0f, 0.0f); // Results should be exact.
1208 
1209         for (int layer = 0; layer < colorTcuExtent.z(); ++layer)
1210         {
1211             // This should match the fragment shader.
1212             const auto green = ((layer > 0) ? 1.0f : 0.0f);
1213             const auto referenceColor =
1214                 ((m_params->getTotalDrawCount() > 0u) ? tcu::Vec4(0.0f, green, 1.0f, 1.0f) : clearColor);
1215             const auto layerAccess =
1216                 tcu::getSubregion(verifAccess, 0, 0, layer, colorTcuExtent.x(), colorTcuExtent.y(), 1);
1217 
1218             if (!tcu::floatThresholdCompare(log, "Color Result", "", referenceColor, layerAccess, threshold,
1219                                             tcu::COMPARE_LOG_ON_ERROR))
1220             {
1221                 std::ostringstream msg;
1222                 msg << "Color target mismatch at layer " << layer << "; check log for details";
1223                 TCU_FAIL(msg.str());
1224             }
1225         }
1226     }
1227 
1228     // Verify query results.
1229     {
1230         const auto itemSize = querySizesAndOffsets.queryItemSize;
1231         uint8_t *resultsPtr = nullptr;
1232 
1233         if (m_params->access == AccessMethod::COPY)
1234         {
1235             auto &queryResultsBufferAlloc = queryResultsBuffer->getAllocation();
1236             void *queryResultsBufferData  = queryResultsBufferAlloc.getHostPtr();
1237             invalidateAlloc(vkd, device, queryResultsBufferAlloc);
1238 
1239             resultsPtr = reinterpret_cast<uint8_t *>(queryResultsBufferData);
1240         }
1241         else if (m_params->access == AccessMethod::GET)
1242         {
1243             resultsPtr = queryResultsHostVec.data();
1244         }
1245 
1246         if (hasPrimitivesQuery)
1247         {
1248             const std::string queryGroupName = "Primitive count";
1249             uint64_t totalPrimitiveCount     = 0ull;
1250 
1251             for (uint32_t viewIndex = 0u; viewIndex < viewCount; ++viewIndex)
1252             {
1253                 const std::string queryName   = queryGroupName + " for view " + std::to_string(viewIndex);
1254                 const uint64_t primitiveCount = readFromPtrAndAdvance(&resultsPtr, itemSize);
1255 
1256                 totalPrimitiveCount += primitiveCount;
1257 
1258                 if (m_params->availabilityBit)
1259                     readAndVerifyAvailabilityBit(&resultsPtr, itemSize, *m_params, queryName);
1260             }
1261 
1262             verifyQueryCounter(totalPrimitiveCount, expectedPrims, expectedPrims * viewCount, *m_params,
1263                                queryGroupName);
1264         }
1265 
1266         if (hasStatsQuery)
1267         {
1268             const std::string queryGroupName = "Stats query";
1269             uint64_t totalTaskInvs           = 0ull;
1270             uint64_t totalMeshInvs           = 0ull;
1271 
1272             for (uint32_t viewIndex = 0u; viewIndex < viewCount; ++viewIndex)
1273             {
1274                 if (hasTaskInvStat)
1275                 {
1276                     const uint64_t taskInvs = readFromPtrAndAdvance(&resultsPtr, itemSize);
1277                     totalTaskInvs += taskInvs;
1278                 }
1279 
1280                 if (hasMeshInvStat)
1281                 {
1282                     const uint64_t meshInvs = readFromPtrAndAdvance(&resultsPtr, itemSize);
1283                     totalMeshInvs += meshInvs;
1284                 }
1285 
1286                 if (m_params->availabilityBit)
1287                 {
1288                     const std::string queryName = queryGroupName + " for view " + std::to_string(viewIndex);
1289                     readAndVerifyAvailabilityBit(&resultsPtr, itemSize, *m_params, queryGroupName);
1290                 }
1291             }
1292 
1293             if (hasTaskInvStat)
1294                 verifyQueryCounter(totalTaskInvs, expectedTaskInv, expectedTaskInv * viewCount, *m_params,
1295                                    "Task invocations");
1296 
1297             if (hasMeshInvStat)
1298                 verifyQueryCounter(totalMeshInvs, expectedMeshInv, expectedMeshInv * viewCount, *m_params,
1299                                    "Mesh invocations");
1300         }
1301     }
1302 
1303     if (m_params->resetType == ResetCase::NONE_WITH_HOST)
1304     {
1305         // We'll reset the different queries that we used before and we'll retrieve results again with GET, forcing availability bit
1306         // and no wait bit. We'll verify availability bits are zero.
1307         uint8_t *resultsPtr = queryResultsHostVec.data();
1308 
1309         // New parameters, based on the existing ones, that match the behavior we expect below.
1310         TestParams postResetParams      = *m_params;
1311         postResetParams.availabilityBit = true;
1312         postResetParams.waitBit         = false;
1313         postResetParams.resetType       = ResetCase::BEFORE_ACCESS;
1314 
1315         const auto postResetFlags         = postResetParams.getQueryResultFlags();
1316         const auto newSizesAndOffsets     = postResetParams.getQuerySizesAndOffsets();
1317         const auto newStatsQueryOffsetSz  = static_cast<size_t>(newSizesAndOffsets.statsQueryOffset);
1318         const auto newStatsDataHostVecPtr = queryResultsHostVec.data() + newStatsQueryOffsetSz;
1319         const auto newStatsRemainingSize  = maxResultSize - newStatsQueryOffsetSz;
1320         const auto itemSize               = newSizesAndOffsets.queryItemSize;
1321 
1322         if (hasPrimitivesQuery)
1323         {
1324             vkd.resetQueryPool(device, primitivesQueryPool.get(), 0u, viewCount);
1325             const auto res = vkd.getQueryPoolResults(device, primitivesQueryPool.get(), 0u, viewCount,
1326                                                      de::dataSize(queryResultsHostVec), queryResultsHostVec.data(),
1327                                                      newSizesAndOffsets.primitivesQuerySize, postResetFlags);
1328             checkGetQueryRes(res, true /*allowNotReady*/);
1329         }
1330 
1331         if (hasStatsQuery)
1332         {
1333             vkd.resetQueryPool(device, statsQueryPool.get(), 0u, viewCount);
1334             const auto res =
1335                 vkd.getQueryPoolResults(device, statsQueryPool.get(), 0u, viewCount, newStatsRemainingSize,
1336                                         newStatsDataHostVecPtr, newSizesAndOffsets.statsQuerySize, postResetFlags);
1337             checkGetQueryRes(res, true /*allowNotReady*/);
1338         }
1339 
1340         if (hasPrimitivesQuery)
1341         {
1342             for (uint32_t viewIndex = 0u; viewIndex < viewCount; ++viewIndex)
1343             {
1344                 const std::string queryName   = "Post-reset primitive count for view " + std::to_string(viewIndex);
1345                 const uint64_t primitiveCount = readFromPtrAndAdvance(&resultsPtr, itemSize);
1346 
1347                 // Resetting a query without beginning it again makes numerical results undefined.
1348                 //verifyQueryCounter(primitiveCount, 0ull, postResetParams, queryName);
1349                 DE_UNREF(primitiveCount);
1350                 readAndVerifyAvailabilityBit(&resultsPtr, itemSize, postResetParams, queryName);
1351             }
1352         }
1353 
1354         if (hasStatsQuery)
1355         {
1356             for (uint32_t viewIndex = 0u; viewIndex < viewCount; ++viewIndex)
1357             {
1358                 if (hasTaskInvStat)
1359                 {
1360                     const uint64_t taskInvs = readFromPtrAndAdvance(&resultsPtr, itemSize);
1361                     // Resetting a query without beginning it again makes numerical results undefined.
1362                     //verifyQueryCounter(taskInvs, 0ull, postResetParams, "Post-reset task invocations");
1363                     DE_UNREF(taskInvs);
1364                 }
1365 
1366                 if (hasMeshInvStat)
1367                 {
1368                     const uint64_t meshInvs = readFromPtrAndAdvance(&resultsPtr, itemSize);
1369                     // Resetting a query without beginning it again makes numerical results undefined.
1370                     //verifyQueryCounter(meshInvs, 0ull, postResetParams, "Post-reset mesh invocations");
1371                     DE_UNREF(meshInvs);
1372                 }
1373 
1374                 const std::string queryName = "Post-reset stats query for view " + std::to_string(viewIndex);
1375                 readAndVerifyAvailabilityBit(&resultsPtr, itemSize, postResetParams, queryName);
1376             }
1377         }
1378     }
1379 
1380     return tcu::TestStatus::pass("Pass");
1381 }
1382 
// de::MovePtr wrapper around a tcu::TestCaseGroup. Test groups are held through this type while they are
// being populated and are handed over to their parent (or to the caller) with release().
using GroupPtr = de::MovePtr<tcu::TestCaseGroup>;
1384 
1385 } // namespace
1386 
createMeshShaderQueryTestsEXT(tcu::TestContext & testCtx)1387 tcu::TestCaseGroup *createMeshShaderQueryTestsEXT(tcu::TestContext &testCtx)
1388 {
1389     GroupPtr queryGroup(new tcu::TestCaseGroup(testCtx, "query"));
1390 
1391     const struct
1392     {
1393         std::vector<QueryType> queryTypes;
1394         const char *name;
1395     } queryCombinations[] = {
1396         {{}, "no_queries"},
1397         {{QueryType::PRIMITIVES}, "prim_query"},
1398         {{QueryType::TASK_INVOCATIONS}, "task_invs_query"},
1399         {{QueryType::MESH_INVOCATIONS}, "mesh_invs_query"},
1400         {{QueryType::TASK_INVOCATIONS, QueryType::MESH_INVOCATIONS}, "all_stats_query"},
1401         {{QueryType::PRIMITIVES, QueryType::TASK_INVOCATIONS, QueryType::MESH_INVOCATIONS}, "all_queries"},
1402     };
1403 
1404     const struct
1405     {
1406         DrawCallType drawCallType;
1407         const char *name;
1408     } drawCalls[] = {
1409         {DrawCallType::DIRECT, "draw"},
1410         {DrawCallType::INDIRECT, "indirect_draw"},
1411         {DrawCallType::INDIRECT_WITH_COUNT, "indirect_with_count_draw"},
1412     };
1413 
1414     const struct
1415     {
1416         std::vector<uint32_t> drawBlocks;
1417         const char *name;
1418     } blockCases[] = {
1419         {{}, "no_blocks"},
1420         {{10u}, "single_block"},
1421         {{10u, 20u, 30u}, "multiple_blocks"},
1422     };
1423 
1424     const struct
1425     {
1426         ResetCase resetCase;
1427         const char *name;
1428     } resetTypes[] = {
1429         {ResetCase::NONE, "no_reset"},
1430         {ResetCase::NONE_WITH_HOST, "host_reset"},
1431         {ResetCase::BEFORE_ACCESS, "reset_before"},
1432         {ResetCase::AFTER_ACCESS, "reset_after"},
1433     };
1434 
1435     const struct
1436     {
1437         AccessMethod accessMethod;
1438         const char *name;
1439     } accessMethods[] = {
1440         {AccessMethod::COPY, "copy"},
1441         {AccessMethod::GET, "get"},
1442     };
1443 
1444     const struct
1445     {
1446         GeometryType geometry;
1447         const char *name;
1448     } geometryCases[] = {
1449         {GeometryType::POINTS, "points"},
1450         {GeometryType::LINES, "lines"},
1451         {GeometryType::TRIANGLES, "triangles"},
1452     };
1453 
1454     const struct
1455     {
1456         bool use64Bits;
1457         const char *name;
1458     } resultSizes[] = {
1459         {false, "32bit"},
1460         {true, "64bit"},
1461     };
1462 
1463     const struct
1464     {
1465         bool availabilityFlag;
1466         const char *name;
1467     } availabilityCases[] = {
1468         {false, "no_availability"},
1469         {true, "with_availability"},
1470     };
1471 
1472     const struct
1473     {
1474         bool waitFlag;
1475         const char *name;
1476     } waitCases[] = {
1477         {false, "no_wait"},
1478         {true, "wait"},
1479     };
1480 
1481     const struct
1482     {
1483         bool taskShader;
1484         const char *name;
1485     } taskShaderCases[] = {
1486         {false, "mesh_only"},
1487         {true, "task_mesh"},
1488     };
1489 
1490     const struct
1491     {
1492         bool insideRenderPass;
1493         const char *name;
1494     } orderingCases[] = {
1495         {false, "include_rp"},
1496         {true, "inside_rp"},
1497     };
1498 
1499     const struct
1500     {
1501         bool multiView;
1502         const char *name;
1503     } multiViewCases[] = {
1504         {false, "single_view"},
1505         {true, "multi_view"},
1506     };
1507 
1508     const struct
1509     {
1510         bool useSecondary;
1511         const char *name;
1512     } cmdBufferTypes[] = {
1513         {false, "only_primary"},
1514         {true, "with_secondary"},
1515     };
1516 
1517     for (const auto &queryCombination : queryCombinations)
1518     {
1519         const bool noQueries = queryCombination.queryTypes.empty();
1520         const bool hasPrimitivesQuery =
1521             de::contains(queryCombination.queryTypes.begin(), queryCombination.queryTypes.end(), QueryType::PRIMITIVES);
1522 
1523         GroupPtr queryCombinationGroup(new tcu::TestCaseGroup(testCtx, queryCombination.name));
1524 
1525         for (const auto &geometryCase : geometryCases)
1526         {
1527             if (noQueries && geometryCase.geometry != GeometryType::LINES)
1528                 continue;
1529 
1530             const bool nonTriangles = (geometryCase.geometry != GeometryType::TRIANGLES);
1531 
1532             // For cases without primitive queries, skip non-triangle geometries.
1533             if (!hasPrimitivesQuery && !noQueries && nonTriangles)
1534                 continue;
1535 
1536             GroupPtr geometryCaseGroup(new tcu::TestCaseGroup(testCtx, geometryCase.name));
1537 
1538             for (const auto &resetType : resetTypes)
1539             {
1540                 if (noQueries && resetType.resetCase != ResetCase::NONE)
1541                     continue;
1542 
1543                 GroupPtr resetTypeGroup(new tcu::TestCaseGroup(testCtx, resetType.name));
1544 
1545                 for (const auto &accessMethod : accessMethods)
1546                 {
1547                     if (noQueries && accessMethod.accessMethod != AccessMethod::COPY)
1548                         continue;
1549 
1550                     // Get + reset after access is not a valid combination (queries will be accessed after submission).
1551                     if (accessMethod.accessMethod == AccessMethod::GET &&
1552                         resetType.resetCase == ResetCase::AFTER_ACCESS)
1553                         continue;
1554 
1555                     GroupPtr accessMethodGroup(new tcu::TestCaseGroup(testCtx, accessMethod.name));
1556 
1557                     for (const auto &waitCase : waitCases)
1558                     {
1559                         if (noQueries && waitCase.waitFlag)
1560                             continue;
1561 
1562                         // Wait and reset before access is not valid (the query would never finish).
1563                         if (resetType.resetCase == ResetCase::BEFORE_ACCESS && waitCase.waitFlag)
1564                             continue;
1565 
1566                         GroupPtr waitCaseGroup(new tcu::TestCaseGroup(testCtx, waitCase.name));
1567 
1568                         for (const auto &drawCall : drawCalls)
1569                         {
1570                             // Explicitly remove some combinations with non-triangles, just to reduce the number of tests.
1571                             if (drawCall.drawCallType != DrawCallType::DIRECT && nonTriangles)
1572                                 continue;
1573 
1574                             GroupPtr drawCallGroup(new tcu::TestCaseGroup(testCtx, drawCall.name));
1575 
1576                             for (const auto &resultSize : resultSizes)
1577                             {
1578                                 if (noQueries && resultSize.use64Bits)
1579                                     continue;
1580 
1581                                 // Explicitly remove some combinations with non-triangles, just to reduce the number of tests.
1582                                 if (resultSize.use64Bits && nonTriangles)
1583                                     continue;
1584 
1585                                 GroupPtr resultSizeGroup(new tcu::TestCaseGroup(testCtx, resultSize.name));
1586 
1587                                 for (const auto &availabilityCase : availabilityCases)
1588                                 {
1589                                     if (noQueries && availabilityCase.availabilityFlag)
1590                                         continue;
1591 
1592                                     // Explicitly remove some combinations with non-triangles, just to reduce the number of tests.
1593                                     if (availabilityCase.availabilityFlag && nonTriangles)
1594                                         continue;
1595 
1596                                     GroupPtr availabilityCaseGroup(
1597                                         new tcu::TestCaseGroup(testCtx, availabilityCase.name));
1598 
1599                                     for (const auto &blockCase : blockCases)
1600                                     {
1601                                         // Explicitly remove some combinations with non-triangles, just to reduce the number of tests.
1602                                         if (blockCase.drawBlocks.size() <= 1 && nonTriangles)
1603                                             continue;
1604 
1605                                         GroupPtr blockCaseGroup(new tcu::TestCaseGroup(testCtx, blockCase.name));
1606 
1607                                         for (const auto &taskShaderCase : taskShaderCases)
1608                                         {
1609                                             GroupPtr taskShaderCaseGroup(
1610                                                 new tcu::TestCaseGroup(testCtx, taskShaderCase.name));
1611 
1612                                             for (const auto &orderingCase : orderingCases)
1613                                             {
1614                                                 if (noQueries && !orderingCase.insideRenderPass)
1615                                                     continue;
1616 
1617                                                 GroupPtr orderingCaseGroup(
1618                                                     new tcu::TestCaseGroup(testCtx, orderingCase.name));
1619 
1620                                                 for (const auto &multiViewCase : multiViewCases)
1621                                                 {
1622                                                     if (multiViewCase.multiView && !orderingCase.insideRenderPass)
1623                                                         continue;
1624 
1625                                                     GroupPtr multiViewGroup(
1626                                                         new tcu::TestCaseGroup(testCtx, multiViewCase.name));
1627 
1628                                                     for (const auto &cmdBufferType : cmdBufferTypes)
1629                                                     {
1630                                                         TestParams params;
1631                                                         params.queryTypes       = queryCombination.queryTypes;
1632                                                         params.drawBlocks       = blockCase.drawBlocks;
1633                                                         params.drawCall         = drawCall.drawCallType;
1634                                                         params.geometry         = geometryCase.geometry;
1635                                                         params.resetType        = resetType.resetCase;
1636                                                         params.access           = accessMethod.accessMethod;
1637                                                         params.use64Bits        = resultSize.use64Bits;
1638                                                         params.availabilityBit  = availabilityCase.availabilityFlag;
1639                                                         params.waitBit          = waitCase.waitFlag;
1640                                                         params.useTaskShader    = taskShaderCase.taskShader;
1641                                                         params.insideRenderPass = orderingCase.insideRenderPass;
1642                                                         params.useSecondary     = cmdBufferType.useSecondary;
1643                                                         params.multiView        = multiViewCase.multiView;
1644 
1645                                                         // VUID-vkCmdExecuteCommands-commandBuffer-07594
1646                                                         if (params.areQueriesInherited() && params.hasPrimitivesQuery())
1647                                                             continue;
1648 
1649                                                         multiViewGroup->addChild(new MeshQueryCase(
1650                                                             testCtx, cmdBufferType.name, std::move(params)));
1651                                                     }
1652 
1653                                                     orderingCaseGroup->addChild(multiViewGroup.release());
1654                                                 }
1655 
1656                                                 taskShaderCaseGroup->addChild(orderingCaseGroup.release());
1657                                             }
1658 
1659                                             blockCaseGroup->addChild(taskShaderCaseGroup.release());
1660                                         }
1661 
1662                                         availabilityCaseGroup->addChild(blockCaseGroup.release());
1663                                     }
1664 
1665                                     resultSizeGroup->addChild(availabilityCaseGroup.release());
1666                                 }
1667 
1668                                 drawCallGroup->addChild(resultSizeGroup.release());
1669                             }
1670 
1671                             waitCaseGroup->addChild(drawCallGroup.release());
1672                         }
1673 
1674                         accessMethodGroup->addChild(waitCaseGroup.release());
1675                     }
1676 
1677                     resetTypeGroup->addChild(accessMethodGroup.release());
1678                 }
1679 
1680                 geometryCaseGroup->addChild(resetTypeGroup.release());
1681             }
1682 
1683             queryCombinationGroup->addChild(geometryCaseGroup.release());
1684         }
1685 
1686         queryGroup->addChild(queryCombinationGroup.release());
1687     }
1688 
1689     return queryGroup.release();
1690 }
1691 
1692 } // namespace MeshShader
1693 } // namespace vkt
1694