xref: /aosp_15_r20/external/deqp/external/vulkancts/modules/vulkan/mesh_shader/vktMeshShaderApiTestsEXT.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  * Copyright (c) 2021 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Mesh Shader API Tests for VK_EXT_mesh_shader
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktMeshShaderApiTestsEXT.hpp"
26 #include "vktMeshShaderUtil.hpp"
27 #include "vktTestCase.hpp"
28 
29 #include "vkTypeUtil.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkBufferWithMemory.hpp"
32 #include "vkObjUtil.hpp"
33 #include "vkBuilderUtil.hpp"
34 #include "vkCmdUtil.hpp"
35 #include "vkImageUtil.hpp"
36 
37 #include "tcuMaybe.hpp"
38 #include "tcuTestLog.hpp"
39 #include "tcuImageCompare.hpp"
40 
41 #include "deRandom.hpp"
42 
43 #include <iostream>
44 #include <sstream>
45 #include <vector>
46 #include <algorithm>
47 #include <iterator>
48 #include <limits>
49 
50 namespace vkt
51 {
52 namespace MeshShader
53 {
54 
55 namespace
56 {
57 
58 using namespace vk;
59 
60 using GroupPtr            = de::MovePtr<tcu::TestCaseGroup>;
61 using ImageWithMemoryPtr  = de::MovePtr<ImageWithMemory>;
62 using BufferWithMemoryPtr = de::MovePtr<BufferWithMemory>;
63 
// Which mesh-task draw command a test variant records.
enum class DrawType
{
    DRAW                = 0, // vkCmdDrawMeshTasksEXT
    DRAW_INDIRECT       = 1, // vkCmdDrawMeshTasksIndirectEXT
    DRAW_INDIRECT_COUNT = 2, // vkCmdDrawMeshTasksIndirectCountEXT
};
70 
operator <<(std::ostream & stream,DrawType drawType)71 std::ostream &operator<<(std::ostream &stream, DrawType drawType)
72 {
73     switch (drawType)
74     {
75     case DrawType::DRAW:
76         stream << "draw";
77         break;
78     case DrawType::DRAW_INDIRECT:
79         stream << "draw_indirect";
80         break;
81     case DrawType::DRAW_INDIRECT_COUNT:
82         stream << "draw_indirect_count";
83         break;
84     default:
85         DE_ASSERT(false);
86         break;
87     }
88     return stream;
89 }
90 
// Selects which argument of vkCmdDrawMeshTasksIndirectCountEXT bounds the number of draws in the DRAW_INDIRECT_COUNT case.
// The effective count is the smaller of the count-buffer value and maxDrawCount, so both sides of that rule are exercised.
enum class IndirectCountLimitType
{
    BUFFER_VALUE = 0, // The count buffer holds the real draw count; maxDrawCount is set to a larger value.
    MAX_COUNT    = 1, // maxDrawCount holds the real draw count; the count buffer holds a larger value.
};
97 
// Placement of the VkDrawMeshTasksIndirectCommandEXT array inside the indirect buffer.
struct IndirectArgs
{
    // Byte offset of the first command; expected to be a multiple of 4.
    uint32_t offset;
    // Byte distance between consecutive commands; a multiple of 4 that is either 0 (tightly packed when the buffer is
    // built) or at least sizeof(VkDrawMeshTasksIndirectCommandEXT).
    uint32_t stride;
};
103 
104 struct TestParams
105 {
106     DrawType drawType;
107     uint32_t seed;
108     uint32_t drawCount;                                    // Equivalent to taskCount or drawCount.
109     tcu::Maybe<IndirectArgs> indirectArgs;                 // Only used for DRAW_INDIRECT*.
110     tcu::Maybe<IndirectCountLimitType> indirectCountLimit; // Only used for DRAW_INDIRECT_COUNT.
111     tcu::Maybe<uint32_t> indirectCountOffset;              // Only used for DRAW_INDIRECT_COUNT.
112     bool useTask;
113     bool useSecondaryCmdBuffer;
114 };
115 
// The framebuffer will have a number of rows and 32 columns. Each mesh shader workgroup will generate geometry to fill a single
// framebuffer row, using a triangle list with 32 triangles of different colors, each covering a framebuffer pixel.
//
// Note: the total number of framebuffer rows is called "full" below (e.g. 64). When using a task shader to generate work, each
// task workgroup will launch a single mesh workgroup, taking that count from a push constant instead of a compile-time constant.
//
// When using DRAW, the task count will tell us how many rows of pixels will be filled in the framebuffer.
//
// When using indirect draws, the full framebuffer will always be drawn into by using multiple draw command structures, except in
// the case of drawCount==0. Each draw will spawn the number of tasks needed so that, together, the draws fill the whole
// framebuffer. In addition, in order to make all argument structures different, the number of tasks in each draw will be
// slightly different and assigned pseudorandomly.
//
// DRAW: taskCount=0, taskCount=1, taskCount=2, taskCount=half, taskCount=full
//
// DRAW_INDIRECT: drawCount=0, drawCount=1, drawCount=2, drawCount=half, drawCount=full.
//  * With an offset of 0, and with a pseudorandom offset (multiple of 4).
//  * With a stride that adds no padding, and with a stride that adds pseudorandom padding (multiple of 4).
//
// DRAW_INDIRECT_COUNT: same as DRAW_INDIRECT, in two variants:
//  1. Passing the real count in the count buffer, with a large maxDrawCount.
//  2. Passing a large value in the count buffer and limiting it with maxDrawCount.
138 
139 class MeshApiCase : public vkt::TestCase
140 {
141 public:
MeshApiCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)142     MeshApiCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params)
143         : vkt::TestCase(testCtx, name)
144         , m_params(params)
145     {
146     }
~MeshApiCase(void)147     virtual ~MeshApiCase(void)
148     {
149     }
150 
151     void initPrograms(vk::SourceCollections &programCollection) const override;
152     void checkSupport(Context &context) const override;
153     TestInstance *createInstance(Context &context) const override;
154 
155 protected:
156     TestParams m_params;
157 };
158 
159 class MeshApiInstance : public vkt::TestInstance
160 {
161 public:
MeshApiInstance(Context & context,const TestParams & params)162     MeshApiInstance(Context &context, const TestParams &params) : vkt::TestInstance(context), m_params(params)
163     {
164     }
~MeshApiInstance(void)165     virtual ~MeshApiInstance(void)
166     {
167     }
168 
169     tcu::TestStatus iterate(void) override;
170 
171 protected:
172     TestParams m_params;
173 };
174 
createInstance(Context & context) const175 TestInstance *MeshApiCase::createInstance(Context &context) const
176 {
177     return new MeshApiInstance(context, m_params);
178 }
179 
180 struct PushConstantData
181 {
182     uint32_t width;
183     uint32_t height;
184     uint32_t dimMesh; // Set work group size in the X, Y or Z dimension depending on value (0, 1, 2).
185     uint32_t one;
186     uint32_t dimTask; // Same as dimMesh.
187 
getRangesvkt::MeshShader::__anond8eeac9e0111::PushConstantData188     std::vector<VkPushConstantRange> getRanges(bool includeTask) const
189     {
190         constexpr uint32_t offsetMesh = 0u;
191         constexpr uint32_t offsetTask = static_cast<uint32_t>(offsetof(PushConstantData, one));
192         constexpr uint32_t sizeMesh   = offsetTask;
193         constexpr uint32_t sizeTask   = static_cast<uint32_t>(sizeof(PushConstantData)) - offsetTask;
194 
195         const VkPushConstantRange meshRange = {
196             VK_SHADER_STAGE_MESH_BIT_EXT, // VkShaderStageFlags stageFlags;
197             offsetMesh,                   // uint32_t offset;
198             sizeMesh,                     // uint32_t size;
199         };
200         const VkPushConstantRange taskRange = {
201             VK_SHADER_STAGE_TASK_BIT_EXT, // VkShaderStageFlags stageFlags;
202             offsetTask,                   // uint32_t offset;
203             sizeTask,                     // uint32_t size;
204         };
205 
206         std::vector<VkPushConstantRange> ranges(1u, meshRange);
207         if (includeTask)
208             ranges.push_back(taskRange);
209         return ranges;
210     }
211 };
212 
initPrograms(vk::SourceCollections & programCollection) const213 void MeshApiCase::initPrograms(vk::SourceCollections &programCollection) const
214 {
215     const auto buildOptions = getMinMeshEXTBuildOptions(programCollection.usedVulkanVersion);
216 
217     const std::string taskDataDecl = "struct TaskData {\n"
218                                      "    uint blockNumber;\n"
219                                      "    uint blockRow;\n"
220                                      "};\n"
221                                      "taskPayloadSharedEXT TaskData td;\n";
222 
223     // Task shader if needed.
224     if (m_params.useTask)
225     {
226         std::ostringstream task;
227         task << "#version 460\n"
228              << "#extension GL_EXT_mesh_shader : enable\n"
229              << "\n"
230              << "layout (local_size_x=1) in;\n"
231              << "\n"
232              << "layout (push_constant, std430) uniform TaskPushConstantBlock {\n"
233              << "    layout (offset=12) uint one;\n"
234              << "    layout (offset=16) uint dimCoord;\n"
235              << "} pc;\n"
236              << "\n"
237              << taskDataDecl << "\n"
238              << "void main ()\n"
239              << "{\n"
240              << "    const uint workGroupID = ((pc.dimCoord == 2) ? gl_WorkGroupID.z : ((pc.dimCoord == 1) ? "
241                 "gl_WorkGroupID.y : gl_WorkGroupID.x));\n"
242              << "    td.blockNumber         = uint(gl_DrawID);\n"
243              << "    td.blockRow            = workGroupID;\n"
244              << "    EmitMeshTasksEXT(pc.one, pc.one, pc.one);"
245              << "}\n";
246         programCollection.glslSources.add("task") << glu::TaskSource(task.str()) << buildOptions;
247     }
248 
249     // Mesh shader.
250     {
251         std::ostringstream mesh;
252         mesh << "#version 460\n"
253              << "#extension GL_EXT_mesh_shader : enable\n"
254              << "\n"
255              << "// 32 local invocations in total.\n"
256              << "layout (local_size_x=4, local_size_y=2, local_size_z=4) in;\n"
257              << "layout (triangles) out;\n"
258              << "layout (max_vertices=96, max_primitives=32) out;\n"
259              << "\n"
260              << "layout (push_constant, std430) uniform MeshPushConstantBlock {\n"
261              << "    uint width;\n"
262              << "    uint height;\n"
263              << "    uint dimCoord;\n"
264              << "} pc;\n"
265              << "\n"
266              << "layout (location=0) perprimitiveEXT out vec4 primitiveColor[];\n"
267              << "\n"
268              << (m_params.useTask ? taskDataDecl : "") << "\n"
269              << "layout (set=0, binding=0, std430) readonly buffer BlockSizes {\n"
270              << "    uint blockSize[];\n"
271              << "} bsz;\n"
272              << "\n"
273              << "uint startOfBlock (uint blockNumber)\n"
274              << "{\n"
275              << "    uint start = 0;\n"
276              << "    for (uint i = 0; i < blockNumber; i++)\n"
277              << "        start += bsz.blockSize[i];\n"
278              << "    return start;\n"
279              << "}\n"
280              << "\n"
281              << "void main ()\n"
282              << "{\n"
283              << "    const uint workGroupID = ((pc.dimCoord == 2) ? gl_WorkGroupID.z : ((pc.dimCoord == 1) ? "
284                 "gl_WorkGroupID.y : gl_WorkGroupID.x));\n"
285              << "    const uint blockNumber = " << (m_params.useTask ? "td.blockNumber" : "uint(gl_DrawID)") << ";\n"
286              << "    const uint blockRow = " << (m_params.useTask ? "td.blockRow" : "workGroupID") << ";\n"
287              << "\n"
288              << "    // Each workgroup will fill one row, and each invocation will generate a\n"
289              << "    // triangle around the pixel center in each column.\n"
290              << "    const uint row = startOfBlock(blockNumber) + blockRow;\n"
291              << "    const uint col = gl_LocalInvocationIndex;\n"
292              << "\n"
293              << "    const float fHeight = float(pc.height);\n"
294              << "    const float fWidth = float(pc.width);\n"
295              << "\n"
296              << "    // Pixel coordinates, normalized.\n"
297              << "    const float rowNorm = (float(row) + 0.5) / fHeight;\n"
298              << "    const float colNorm = (float(col) + 0.5) / fWidth;\n"
299              << "\n"
300              << "    // Framebuffer coordinates.\n"
301              << "    const float coordX = (colNorm * 2.0) - 1.0;\n"
302              << "    const float coordY = (rowNorm * 2.0) - 1.0;\n"
303              << "\n"
304              << "    const float pixelWidth = 2.0 / fWidth;\n"
305              << "    const float pixelHeight = 2.0 / fHeight;\n"
306              << "\n"
307              << "    const float offsetX = pixelWidth / 2.0;\n"
308              << "    const float offsetY = pixelHeight / 2.0;\n"
309              << "\n"
310              << "    const uint baseIndex = col*3;\n"
311              << "    const uvec3 indices = uvec3(baseIndex, baseIndex + 1, baseIndex + 2);\n"
312              << "\n"
313              << "    SetMeshOutputsEXT(96u, 32u);\n"
314              << "    primitiveColor[col] = vec4(rowNorm, colNorm, 0.0, 1.0);\n"
315              << "    gl_PrimitiveTriangleIndicesEXT[col] = uvec3(indices.x, indices.y, indices.z);\n"
316              << "\n"
317              << "    gl_MeshVerticesEXT[indices.x].gl_Position = vec4(coordX - offsetX, coordY + offsetY, 0.0, 1.0);\n"
318              << "    gl_MeshVerticesEXT[indices.y].gl_Position = vec4(coordX + offsetX, coordY + offsetY, 0.0, 1.0);\n"
319              << "    gl_MeshVerticesEXT[indices.z].gl_Position = vec4(coordX, coordY - offsetY, 0.0, 1.0);\n"
320              << "}\n";
321         programCollection.glslSources.add("mesh") << glu::MeshSource(mesh.str()) << buildOptions;
322     }
323 
324     // Frag shader.
325     {
326         std::ostringstream frag;
327         frag << "#version 460\n"
328              << "#extension GL_EXT_mesh_shader : enable\n"
329              << "\n"
330              << "layout (location=0) perprimitiveEXT in vec4 primitiveColor;\n"
331              << "layout (location=0) out vec4 outColor;\n"
332              << "\n"
333              << "void main ()\n"
334              << "{\n"
335              << "    outColor = primitiveColor;\n"
336              << "}\n";
337         programCollection.glslSources.add("frag") << glu::FragmentSource(frag.str()) << buildOptions;
338     }
339 }
340 
checkSupport(Context & context) const341 void MeshApiCase::checkSupport(Context &context) const
342 {
343     checkTaskMeshShaderSupportEXT(context, m_params.useTask, true);
344 
345     // VUID-vkCmdDrawMeshTasksIndirectEXT-drawCount-02718
346     if (m_params.drawType == DrawType::DRAW_INDIRECT && m_params.drawCount > 1u)
347     {
348         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_MULTI_DRAW_INDIRECT);
349     }
350 
351     // VUID-vkCmdDrawMeshTasksIndirectCountEXT-None-04445
352     if (m_params.drawType == DrawType::DRAW_INDIRECT_COUNT)
353         context.requireDeviceFunctionality("VK_KHR_draw_indirect_count");
354 }
355 
356 template <typename T>
makeStridedBuffer(const DeviceInterface & vkd,VkDevice device,Allocator & alloc,const std::vector<T> & elements,uint32_t offset,uint32_t stride,VkBufferUsageFlags usage,uint32_t endPadding)357 BufferWithMemoryPtr makeStridedBuffer(const DeviceInterface &vkd, VkDevice device, Allocator &alloc,
358                                       const std::vector<T> &elements, uint32_t offset, uint32_t stride,
359                                       VkBufferUsageFlags usage, uint32_t endPadding)
360 {
361     const auto elementSize  = static_cast<uint32_t>(sizeof(T));
362     const auto actualStride = std::max(elementSize, stride);
363     const auto bufferSize   = static_cast<size_t>(offset) + static_cast<size_t>(actualStride) * elements.size() +
364                             static_cast<size_t>(endPadding);
365     const auto bufferInfo = makeBufferCreateInfo(static_cast<VkDeviceSize>(bufferSize), usage);
366 
367     BufferWithMemoryPtr buffer(new BufferWithMemory(vkd, device, alloc, bufferInfo, MemoryRequirement::HostVisible));
368     auto &bufferAlloc   = buffer->getAllocation();
369     char *bufferDataPtr = reinterpret_cast<char *>(bufferAlloc.getHostPtr());
370 
371     char *itr = bufferDataPtr + offset;
372     for (const auto &elem : elements)
373     {
374         deMemcpy(itr, &elem, sizeof(elem));
375         itr += actualStride;
376     }
377     if (endPadding > 0u)
378         deMemset(itr, 0xFF, endPadding);
379 
380     flushAlloc(vkd, device, bufferAlloc);
381 
382     return buffer;
383 }
384 
getExtent()385 VkExtent3D getExtent()
386 {
387     return makeExtent3D(32u, 64u, 1u);
388 }
389 
getIndirectCommand(uint32_t blockSize,uint32_t dimCoord)390 VkDrawMeshTasksIndirectCommandEXT getIndirectCommand(uint32_t blockSize, uint32_t dimCoord)
391 {
392     VkDrawMeshTasksIndirectCommandEXT indirectCmd{1u, 1u, 1u};
393 
394     switch (dimCoord)
395     {
396     case 0u:
397         indirectCmd.groupCountX = blockSize;
398         break;
399     case 1u:
400         indirectCmd.groupCountY = blockSize;
401         break;
402     case 2u:
403         indirectCmd.groupCountZ = blockSize;
404         break;
405     default:
406         DE_ASSERT(false);
407         break;
408     }
409 
410     return indirectCmd;
411 }
412 
iterate(void)413 tcu::TestStatus MeshApiInstance::iterate(void)
414 {
415     const auto &vkd       = m_context.getDeviceInterface();
416     const auto device     = m_context.getDevice();
417     auto &alloc           = m_context.getDefaultAllocator();
418     const auto queueIndex = m_context.getUniversalQueueFamilyIndex();
419     const auto queue      = m_context.getUniversalQueue();
420 
421     const auto extent = getExtent();
422     const auto iExtent3D =
423         tcu::IVec3(static_cast<int>(extent.width), static_cast<int>(extent.height), static_cast<int>(extent.depth));
424     const auto iExtent2D  = tcu::IVec2(iExtent3D.x(), iExtent3D.y());
425     const auto format     = VK_FORMAT_R8G8B8A8_UNORM;
426     const auto tcuFormat  = mapVkFormat(format);
427     const auto colorUsage = (VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT);
428     const auto colorSRR   = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
429     const tcu::Vec4 clearColor(0.0f, 0.0f, 0.0f, 1.0f);
430     const float colorThres = 0.005f; // 1/255 < 0.005 < 2/255
431     const tcu::Vec4 threshold(colorThres, colorThres, 0.0f, 0.0f);
432 
433     ImageWithMemoryPtr colorBuffer;
434     Move<VkImageView> colorBufferView;
435     {
436         const VkImageCreateInfo colorBufferInfo = {
437             VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
438             nullptr,                             // const void* pNext;
439             0u,                                  // VkImageCreateFlags flags;
440             VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
441             format,                              // VkFormat format;
442             extent,                              // VkExtent3D extent;
443             1u,                                  // uint32_t mipLevels;
444             1u,                                  // uint32_t arrayLayers;
445             VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
446             VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
447             colorUsage,                          // VkImageUsageFlags usage;
448             VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
449             0u,                                  // uint32_t queueFamilyIndexCount;
450             nullptr,                             // const uint32_t* pQueueFamilyIndices;
451             VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
452         };
453         colorBuffer =
454             ImageWithMemoryPtr(new ImageWithMemory(vkd, device, alloc, colorBufferInfo, MemoryRequirement::Any));
455         colorBufferView = makeImageView(vkd, device, colorBuffer->get(), VK_IMAGE_VIEW_TYPE_2D, format, colorSRR);
456     }
457 
458     // Prepare buffer containing the array of block sizes.
459     de::Random rnd(m_params.seed);
460     std::vector<uint32_t> blockSizes;
461 
462     const uint32_t vectorSize = std::max(1u, m_params.drawCount);
463     const uint32_t largeDrawCount =
464         vectorSize + 1u; // The indirect buffer needs to have some padding at the end. See below.
465     const uint32_t evenBlockSize = extent.height / vectorSize;
466     uint32_t remainingRows       = extent.height;
467 
468     blockSizes.reserve(vectorSize);
469     for (uint32_t i = 0; i < vectorSize - 1u; ++i)
470     {
471         const auto blockSize = static_cast<uint32_t>(rnd.getInt(1, evenBlockSize));
472         remainingRows -= blockSize;
473         blockSizes.push_back(blockSize);
474     }
475     blockSizes.push_back(remainingRows);
476 
477     const auto blockSizesBufferSize = static_cast<VkDeviceSize>(de::dataSize(blockSizes));
478     BufferWithMemoryPtr blockSizesBuffer =
479         makeStridedBuffer(vkd, device, alloc, blockSizes, 0u, 0u, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, 0u);
480 
481     // Descriptor set layout, pool and set.
482     DescriptorSetLayoutBuilder layoutBuilder;
483     layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_SHADER_STAGE_MESH_BIT_EXT);
484     const auto setLayout = layoutBuilder.build(vkd, device);
485 
486     DescriptorPoolBuilder poolBuilder;
487     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
488     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
489 
490     const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
491 
492     // Update descriptor set.
493     {
494         DescriptorSetUpdateBuilder updateBuilder;
495 
496         const auto location             = DescriptorSetUpdateBuilder::Location::binding(0u);
497         const auto descriptorBufferInfo = makeDescriptorBufferInfo(blockSizesBuffer->get(), 0ull, blockSizesBufferSize);
498 
499         updateBuilder.writeSingle(descriptorSet.get(), location, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
500                                   &descriptorBufferInfo);
501         updateBuilder.update(vkd, device);
502     }
503 
504     // Pipeline layout.
505     PushConstantData pcData;
506     const auto pcRanges       = pcData.getRanges(m_params.useTask);
507     const auto pipelineLayout = makePipelineLayout(vkd, device, 1u, &setLayout.get(),
508                                                    static_cast<uint32_t>(pcRanges.size()), de::dataOrNull(pcRanges));
509 
510     // Push constants: choose used dimension coordinate pseudorandomly.
511     const auto dimCoord = rnd.getUint32() % 3u;
512 
513     pcData.width   = extent.width;
514     pcData.height  = extent.height;
515     pcData.dimMesh = dimCoord;
516     pcData.one     = 1u;
517     pcData.dimTask = dimCoord;
518 
519     // Render pass and framebuffer.
520     const auto renderPass = makeRenderPass(vkd, device, format);
521     const auto framebuffer =
522         makeFramebuffer(vkd, device, renderPass.get(), colorBufferView.get(), extent.width, extent.height);
523 
524     // Pipeline.
525     Move<VkShaderModule> taskModule;
526     Move<VkShaderModule> meshModule;
527     Move<VkShaderModule> fragModule;
528 
529     const auto &binaries = m_context.getBinaryCollection();
530     if (m_params.useTask)
531         taskModule = createShaderModule(vkd, device, binaries.get("task"));
532     meshModule = createShaderModule(vkd, device, binaries.get("mesh"));
533     fragModule = createShaderModule(vkd, device, binaries.get("frag"));
534 
535     const std::vector<VkViewport> viewports(1u, makeViewport(extent));
536     const std::vector<VkRect2D> scissors(1u, makeRect2D(extent));
537 
538     const auto pipeline = makeGraphicsPipeline(vkd, device, pipelineLayout.get(), taskModule.get(), meshModule.get(),
539                                                fragModule.get(), renderPass.get(), viewports, scissors);
540 
541     // Command pool and buffer.
542     const auto subpassContents =
543         (m_params.useSecondaryCmdBuffer ? VK_SUBPASS_CONTENTS_SECONDARY_COMMAND_BUFFERS : VK_SUBPASS_CONTENTS_INLINE);
544     const auto cmdPool          = makeCommandPool(vkd, device, queueIndex);
545     const auto primaryCmdBuffer = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
546     const auto primary          = primaryCmdBuffer.get();
547     const auto secondaryCmdBuffer =
548         (m_params.useSecondaryCmdBuffer ?
549              allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_SECONDARY) :
550              Move<VkCommandBuffer>());
551     const auto secondary = secondaryCmdBuffer.get();
552     const auto rpCmdBuffer =
553         (m_params.useSecondaryCmdBuffer ? secondary : primary); // Holding the contents of the render pass commands.
554 
555     // Indirect and count buffers if needed.
556     BufferWithMemoryPtr indirectBuffer;
557     BufferWithMemoryPtr countBuffer;
558 
559     if (m_params.drawType != DrawType::DRAW)
560     {
561         // Indirect draws.
562         DE_ASSERT(static_cast<bool>(m_params.indirectArgs));
563         const auto &indirectArgs = m_params.indirectArgs.get();
564 
565         // Check stride and offset validity.
566         DE_ASSERT(indirectArgs.offset % 4u == 0u);
567         DE_ASSERT(indirectArgs.stride % 4u == 0u &&
568                   (indirectArgs.stride == 0u ||
569                    indirectArgs.stride >= static_cast<uint32_t>(sizeof(VkDrawMeshTasksIndirectCommandEXT))));
570 
571         // Prepare struct vector, which will be converted to a buffer with the proper stride and offset later.
572         std::vector<VkDrawMeshTasksIndirectCommandEXT> commands;
573         commands.reserve(blockSizes.size());
574 
575         std::transform(begin(blockSizes), end(blockSizes), std::back_inserter(commands),
576                        [dimCoord](uint32_t blockSize) { return getIndirectCommand(blockSize, dimCoord); });
577 
578         const auto padding = static_cast<uint32_t>(sizeof(VkDrawMeshTasksIndirectCommandEXT));
579         indirectBuffer     = makeStridedBuffer(vkd, device, alloc, commands, indirectArgs.offset, indirectArgs.stride,
580                                                VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT, padding);
581 
582         // Prepare count buffer if needed.
583         if (m_params.drawType == DrawType::DRAW_INDIRECT_COUNT)
584         {
585             DE_ASSERT(static_cast<bool>(m_params.indirectCountLimit));
586             DE_ASSERT(static_cast<bool>(m_params.indirectCountOffset));
587 
588             const auto countBufferValue =
589                 ((m_params.indirectCountLimit.get() == IndirectCountLimitType::BUFFER_VALUE) ? m_params.drawCount :
590                                                                                                largeDrawCount);
591 
592             const std::vector<uint32_t> singleCount(1u, countBufferValue);
593             countBuffer =
594                 makeStridedBuffer(vkd, device, alloc, singleCount, m_params.indirectCountOffset.get(),
595                                   static_cast<uint32_t>(sizeof(uint32_t)), VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT, 0u);
596         }
597     }
598 
599     // Submit commands.
600     beginCommandBuffer(vkd, primary);
601     beginRenderPass(vkd, primary, renderPass.get(), framebuffer.get(), scissors.at(0), clearColor, subpassContents);
602 
603     if (m_params.useSecondaryCmdBuffer)
604     {
605         const VkCommandBufferInheritanceInfo inheritanceInfo = {
606             VK_STRUCTURE_TYPE_COMMAND_BUFFER_INHERITANCE_INFO, // VkStructureType                  sType;
607             nullptr,                                           // const void*                      pNext;
608             renderPass.get(),                                  // VkRenderPass                     renderPass;
609             0u,                                                // uint32_t                         subpass;
610             framebuffer.get(),                                 // VkFramebuffer                    framebuffer;
611             VK_FALSE,                                          // VkBool32                         occlusionQueryEnable;
612             0u,                                                // VkQueryControlFlags              queryFlags;
613             0u,                                                // VkQueryPipelineStatisticFlags    pipelineStatistics;
614         };
615 
616         const VkCommandBufferUsageFlags cmdBufferFlags =
617             (VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT | VK_COMMAND_BUFFER_USAGE_RENDER_PASS_CONTINUE_BIT);
618         const VkCommandBufferBeginInfo beginInfo = {
619             VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, // VkStructureType sType;
620             nullptr,                                     // const void* pNext;
621             cmdBufferFlags,                              // VkCommandBufferUsageFlags flags;
622             &inheritanceInfo,                            // const VkCommandBufferInheritanceInfo* pInheritanceInfo;
623         };
624 
625         vkd.beginCommandBuffer(secondary, &beginInfo);
626     }
627 
628     vkd.cmdBindDescriptorSets(rpCmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout.get(), 0u, 1u,
629                               &descriptorSet.get(), 0u, nullptr);
630     {
631         const char *pcDataPtr = reinterpret_cast<const char *>(&pcData);
632         for (const auto &range : pcRanges)
633             vkd.cmdPushConstants(rpCmdBuffer, pipelineLayout.get(), range.stageFlags, range.offset, range.size,
634                                  pcDataPtr + range.offset);
635     }
636     vkd.cmdBindPipeline(rpCmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline.get());
637 
638     if (m_params.drawType == DrawType::DRAW)
639     {
640         const auto drawArgs = getIndirectCommand(m_params.drawCount, dimCoord);
641         vkd.cmdDrawMeshTasksEXT(rpCmdBuffer, drawArgs.groupCountX, drawArgs.groupCountY, drawArgs.groupCountZ);
642     }
643     else if (m_params.drawType == DrawType::DRAW_INDIRECT)
644     {
645         const auto &indirectArgs = m_params.indirectArgs.get();
646         vkd.cmdDrawMeshTasksIndirectEXT(rpCmdBuffer, indirectBuffer->get(), indirectArgs.offset, m_params.drawCount,
647                                         indirectArgs.stride);
648     }
649     else if (m_params.drawType == DrawType::DRAW_INDIRECT_COUNT)
650     {
651         const auto &indirectArgs        = m_params.indirectArgs.get();
652         const auto &indirectCountOffset = m_params.indirectCountOffset.get();
653         const auto &indirectCountLimit  = m_params.indirectCountLimit.get();
654 
655         const auto maxCount =
656             ((indirectCountLimit == IndirectCountLimitType::MAX_COUNT) ? m_params.drawCount : largeDrawCount);
657         vkd.cmdDrawMeshTasksIndirectCountEXT(rpCmdBuffer, indirectBuffer->get(), indirectArgs.offset,
658                                              countBuffer->get(), indirectCountOffset, maxCount, indirectArgs.stride);
659     }
660     else
661         DE_ASSERT(false);
662 
663     if (m_params.useSecondaryCmdBuffer)
664     {
665         endCommandBuffer(vkd, secondary);
666         vkd.cmdExecuteCommands(primary, 1u, &secondary);
667     }
668 
669     endRenderPass(vkd, primary);
670 
671     // Output buffer to extract the color buffer.
672     BufferWithMemoryPtr outBuffer;
673     void *outBufferData = nullptr;
674     {
675         const auto outBufferSize  = static_cast<VkDeviceSize>(static_cast<uint32_t>(tcu::getPixelSize(tcuFormat)) *
676                                                              extent.width * extent.height);
677         const auto outBufferUsage = VK_BUFFER_USAGE_TRANSFER_DST_BIT;
678         const auto outBufferInfo  = makeBufferCreateInfo(outBufferSize, outBufferUsage);
679 
680         outBuffer = BufferWithMemoryPtr(
681             new BufferWithMemory(vkd, device, alloc, outBufferInfo, MemoryRequirement::HostVisible));
682         outBufferData = outBuffer->getAllocation().getHostPtr();
683     }
684 
685     copyImageToBuffer(vkd, primary, colorBuffer->get(), outBuffer->get(), iExtent2D);
686     endCommandBuffer(vkd, primary);
687 
688     submitCommandsAndWait(vkd, device, queue, primary);
689 
690     // Generate reference image and compare.
691     {
692         auto &log            = m_context.getTestContext().getLog();
693         auto &outBufferAlloc = outBuffer->getAllocation();
694         tcu::ConstPixelBufferAccess result(tcuFormat, iExtent3D, outBufferData);
695         tcu::TextureLevel referenceLevel(tcuFormat, iExtent3D.x(), iExtent3D.y());
696         const auto reference = referenceLevel.getAccess();
697         const auto setName   = de::toString(m_params.drawType) + "_draw_count_" + de::toString(m_params.drawCount) +
698                              (m_params.useTask ? "_with_task" : "_no_task");
699         const auto fHeight = static_cast<float>(extent.height);
700         const auto fWidth  = static_cast<float>(extent.width);
701 
702         invalidateAlloc(vkd, device, outBufferAlloc);
703 
704         for (int y = 0; y < iExtent3D.y(); ++y)
705             for (int x = 0; x < iExtent3D.x(); ++x)
706             {
707                 const tcu::Vec4 refColor = ((m_params.drawCount == 0u || (m_params.drawType == DrawType::DRAW &&
708                                                                           y >= static_cast<int>(m_params.drawCount))) ?
709                                                 clearColor :
710                                                 tcu::Vec4(
711                                                     // These match the per-primitive color set by the mesh shader.
712                                                     (static_cast<float>(y) + 0.5f) / fHeight,
713                                                     (static_cast<float>(x) + 0.5f) / fWidth, 0.0f, 1.0f));
714                 reference.setPixel(refColor, x, y);
715             }
716 
717         if (!tcu::floatThresholdCompare(log, setName.c_str(), "", reference, result, threshold,
718                                         tcu::COMPARE_LOG_ON_ERROR))
719             return tcu::TestStatus::fail("Image comparison failed; check log for details");
720     }
721 
722     return tcu::TestStatus::pass("Pass");
723 }
724 
725 } // namespace
726 
createMeshShaderApiTestsEXT(tcu::TestContext & testCtx)727 tcu::TestCaseGroup *createMeshShaderApiTestsEXT(tcu::TestContext &testCtx)
728 {
729     GroupPtr mainGroup(new tcu::TestCaseGroup(testCtx, "api"));
730 
731     const DrawType drawCases[] = {
732         DrawType::DRAW,
733         DrawType::DRAW_INDIRECT,
734         DrawType::DRAW_INDIRECT_COUNT,
735     };
736 
737     const auto extent               = getExtent();
738     const uint32_t drawCountCases[] = {0u, 1u, 2u, extent.height / 2u, extent.height};
739 
740     const uint32_t normalStride = static_cast<uint32_t>(sizeof(VkDrawMeshTasksIndirectCommandEXT));
741     const uint32_t largeStride  = 2u * normalStride + 4u;
742     const uint32_t altOffset    = 20u;
743 
744     const struct
745     {
746         tcu::Maybe<IndirectArgs> indirectArgs;
747         const char *name;
748     } indirectArgsCases[] = {
749         {tcu::nothing<IndirectArgs>(), "no_indirect_args"},
750 
751         // Offset 0, varying strides.
752         {tcu::just(IndirectArgs{0u, 0u}), "offset_0_stride_0"},
753         {tcu::just(IndirectArgs{0u, normalStride}), "offset_0_stride_normal"},
754         {tcu::just(IndirectArgs{0u, largeStride}), "offset_0_stride_large"},
755 
756         // Nonzero offset, varying strides.
757         {tcu::just(IndirectArgs{altOffset, 0u}), "offset_alt_stride_0"},
758         {tcu::just(IndirectArgs{altOffset, normalStride}), "offset_alt_stride_normal"},
759         {tcu::just(IndirectArgs{altOffset, largeStride}), "offset_alt_stride_large"},
760     };
761 
762     const struct
763     {
764         tcu::Maybe<IndirectCountLimitType> limitType;
765         const char *name;
766     } countLimitCases[] = {
767         {tcu::nothing<IndirectCountLimitType>(), "no_count_limit"},
768         {tcu::just(IndirectCountLimitType::BUFFER_VALUE), "count_limit_buffer"},
769         {tcu::just(IndirectCountLimitType::MAX_COUNT), "count_limit_max_count"},
770     };
771 
772     const struct
773     {
774         tcu::Maybe<uint32_t> countOffset;
775         const char *name;
776     } countOffsetCases[] = {
777         {tcu::nothing<uint32_t>(), "no_count_offset"},
778         {tcu::just(uint32_t{0u}), "count_offset_0"},
779         {tcu::just(altOffset), "count_offset_alt"},
780     };
781 
782     const struct
783     {
784         bool useTask;
785         const char *name;
786     } taskCases[] = {
787         {false, "no_task_shader"},
788         {true, "with_task_shader"},
789     };
790 
791     const struct
792     {
793         bool secondaryCmd;
794         const char *suffix;
795     } cmdBufferCases[] = {
796         {false, ""},
797         {true, "_secondary_cmd"},
798     };
799 
800     uint32_t seed = 1628678795u;
801 
802     for (const auto &drawCase : drawCases)
803     {
804         const auto drawCaseName      = de::toString(drawCase);
805         const bool isIndirect        = (drawCase != DrawType::DRAW);
806         const bool isIndirectNoCount = (drawCase == DrawType::DRAW_INDIRECT);
807         const bool isIndirectCount   = (drawCase == DrawType::DRAW_INDIRECT_COUNT);
808 
809         GroupPtr drawGroup(new tcu::TestCaseGroup(testCtx, drawCaseName.c_str()));
810 
811         for (const auto &drawCountCase : drawCountCases)
812         {
813             const auto drawCountName = "draw_count_" + de::toString(drawCountCase);
814             GroupPtr drawCountGroup(new tcu::TestCaseGroup(testCtx, drawCountName.c_str()));
815 
816             for (const auto &indirectArgsCase : indirectArgsCases)
817             {
818                 const bool hasIndirectArgs = static_cast<bool>(indirectArgsCase.indirectArgs);
819                 const bool strideZero      = (hasIndirectArgs && indirectArgsCase.indirectArgs.get().stride == 0u);
820 
821                 if (isIndirect != hasIndirectArgs)
822                     continue;
823 
824                 if (((isIndirectNoCount && drawCountCase > 1u) || isIndirectCount) && strideZero)
825                     continue;
826 
827                 GroupPtr indirectArgsGroup(new tcu::TestCaseGroup(testCtx, indirectArgsCase.name));
828 
829                 for (const auto &countLimitCase : countLimitCases)
830                 {
831                     const bool hasCountLimit = static_cast<bool>(countLimitCase.limitType);
832 
833                     if (isIndirectCount != hasCountLimit)
834                         continue;
835 
836                     GroupPtr countLimitGroup(new tcu::TestCaseGroup(testCtx, countLimitCase.name));
837 
838                     for (const auto &countOffsetCase : countOffsetCases)
839                     {
840                         const bool hasCountOffsetType = static_cast<bool>(countOffsetCase.countOffset);
841 
842                         if (isIndirectCount != hasCountOffsetType)
843                             continue;
844 
845                         GroupPtr countOffsetGroup(new tcu::TestCaseGroup(testCtx, countOffsetCase.name));
846 
847                         for (const auto &taskCase : taskCases)
848                         {
849                             for (const auto &cmdBufferCase : cmdBufferCases)
850                             {
851                                 const auto testName     = std::string(taskCase.name) + cmdBufferCase.suffix;
852                                 const TestParams params = {
853                                     drawCase,                      // DrawType drawType;
854                                     seed++,                        // uint32_t seed;
855                                     drawCountCase,                 // uint32_t drawCount;
856                                     indirectArgsCase.indirectArgs, // tcu::Maybe<IndirectArgs> indirectArgs;
857                                     countLimitCase.limitType, // tcu::Maybe<IndirectCountLimitType> indirectCountLimit;
858                                     countOffsetCase.countOffset, // tcu::Maybe<uint32_t> indirectCountOffset;
859                                     taskCase.useTask,            // bool useTask;
860                                     cmdBufferCase.secondaryCmd,  // bool useSecondaryCmdBuffer;
861                                 };
862 
863                                 countOffsetGroup->addChild(new MeshApiCase(testCtx, testName, params));
864                             }
865                         }
866 
867                         countLimitGroup->addChild(countOffsetGroup.release());
868                     }
869 
870                     indirectArgsGroup->addChild(countLimitGroup.release());
871                 }
872 
873                 drawCountGroup->addChild(indirectArgsGroup.release());
874             }
875 
876             drawGroup->addChild(drawCountGroup.release());
877         }
878 
879         mainGroup->addChild(drawGroup.release());
880     }
881 
882     return mainGroup.release();
883 }
884 
885 } // namespace MeshShader
886 } // namespace vkt
887