xref: /aosp_15_r20/external/deqp/external/vulkancts/modules/vulkan/mesh_shader/vktMeshShaderSmokeTests.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  * Copyright (c) 2021 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Mesh Shader Smoke Tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktMeshShaderSmokeTests.hpp"
26 #include "vktMeshShaderUtil.hpp"
27 #include "vktTestCase.hpp"
28 #include "vktTestCaseUtil.hpp"
29 
30 #include "vkBuilderUtil.hpp"
31 #include "vkImageWithMemory.hpp"
32 #include "vkBufferWithMemory.hpp"
33 #include "vkObjUtil.hpp"
34 #include "vkTypeUtil.hpp"
35 #include "vkCmdUtil.hpp"
36 #include "vkImageUtil.hpp"
37 
38 #include "tcuImageCompare.hpp"
39 #include "tcuTestLog.hpp"
40 #include "tcuTextureUtil.hpp"
41 
42 #include <utility>
43 #include <vector>
44 #include <string>
45 #include <sstream>
46 
47 namespace vkt
48 {
49 namespace MeshShader
50 {
51 
52 namespace
53 {
54 
55 using GroupPtr = de::MovePtr<tcu::TestCaseGroup>;
56 
57 using namespace vk;
58 
// Fragment shader shared by the triangle cases: it writes the per-primitive color
// produced by the mesh shader (location 0) unchanged to color attachment 0.
std::string commonMeshFragShader()
{
    return "#version 450\n"
           "#extension GL_NV_mesh_shader : enable\n"
           "\n"
           "layout (location=0) in perprimitiveNV vec4 triangleColor;\n"
           "layout (location=0) out vec4 outColor;\n"
           "\n"
           "void main ()\n"
           "{\n"
           "    outColor = triangleColor;\n"
           "}\n";
}
72 }
73 
74 struct MeshTriangleRendererParams
75 {
76     std::vector<tcu::Vec4> vertexCoords;
77     std::vector<uint32_t> vertexIndices;
78     uint32_t taskCount;
79     tcu::Vec4 expectedColor;
80 
MeshTriangleRendererParamsvkt::MeshShader::__anone08c37f20111::MeshTriangleRendererParams81     MeshTriangleRendererParams(std::vector<tcu::Vec4> vertexCoords_, std::vector<uint32_t> vertexIndices_,
82                                uint32_t taskCount_, const tcu::Vec4 &expectedColor_)
83         : vertexCoords(std::move(vertexCoords_))
84         , vertexIndices(std::move(vertexIndices_))
85         , taskCount(taskCount_)
86         , expectedColor(expectedColor_)
87     {
88     }
89 
MeshTriangleRendererParamsvkt::MeshShader::__anone08c37f20111::MeshTriangleRendererParams90     MeshTriangleRendererParams(MeshTriangleRendererParams &&other)
91         : MeshTriangleRendererParams(std::move(other.vertexCoords), std::move(other.vertexIndices), other.taskCount,
92                                      other.expectedColor)
93     {
94     }
95 };
96 
97 class MeshOnlyTriangleCase : public vkt::TestCase
98 {
99 public:
MeshOnlyTriangleCase(tcu::TestContext & testCtx,const std::string & name)100     MeshOnlyTriangleCase(tcu::TestContext &testCtx, const std::string &name) : vkt::TestCase(testCtx, name)
101     {
102     }
~MeshOnlyTriangleCase(void)103     virtual ~MeshOnlyTriangleCase(void)
104     {
105     }
106 
107     void initPrograms(vk::SourceCollections &programCollection) const override;
108     TestInstance *createInstance(Context &context) const override;
109     void checkSupport(Context &context) const override;
110 };
111 
112 class MeshTaskTriangleCase : public vkt::TestCase
113 {
114 public:
MeshTaskTriangleCase(tcu::TestContext & testCtx,const std::string & name)115     MeshTaskTriangleCase(tcu::TestContext &testCtx, const std::string &name) : vkt::TestCase(testCtx, name)
116     {
117     }
~MeshTaskTriangleCase(void)118     virtual ~MeshTaskTriangleCase(void)
119     {
120     }
121 
122     void initPrograms(vk::SourceCollections &programCollection) const override;
123     TestInstance *createInstance(Context &context) const override;
124     void checkSupport(Context &context) const override;
125 };
126 
127 // Note: not actually task-only. The task shader will not emit mesh shader work groups.
128 class TaskOnlyTriangleCase : public vkt::TestCase
129 {
130 public:
TaskOnlyTriangleCase(tcu::TestContext & testCtx,const std::string & name)131     TaskOnlyTriangleCase(tcu::TestContext &testCtx, const std::string &name) : vkt::TestCase(testCtx, name)
132     {
133     }
~TaskOnlyTriangleCase(void)134     virtual ~TaskOnlyTriangleCase(void)
135     {
136     }
137 
138     void initPrograms(vk::SourceCollections &programCollection) const override;
139     TestInstance *createInstance(Context &context) const override;
140     void checkSupport(Context &context) const override;
141 };
142 
143 class MeshTriangleRenderer : public vkt::TestInstance
144 {
145 public:
MeshTriangleRenderer(Context & context,MeshTriangleRendererParams params)146     MeshTriangleRenderer(Context &context, MeshTriangleRendererParams params)
147         : vkt::TestInstance(context)
148         , m_params(std::move(params))
149     {
150     }
~MeshTriangleRenderer(void)151     virtual ~MeshTriangleRenderer(void)
152     {
153     }
154 
155     tcu::TestStatus iterate(void) override;
156 
157 protected:
158     MeshTriangleRendererParams m_params;
159 };
160 
checkSupport(Context & context) const161 void MeshOnlyTriangleCase::checkSupport(Context &context) const
162 {
163     checkTaskMeshShaderSupportNV(context, false, true);
164 }
165 
checkSupport(Context & context) const166 void MeshTaskTriangleCase::checkSupport(Context &context) const
167 {
168     checkTaskMeshShaderSupportNV(context, true, true);
169 }
170 
checkSupport(Context & context) const171 void TaskOnlyTriangleCase::checkSupport(Context &context) const
172 {
173     checkTaskMeshShaderSupportNV(context, true, true);
174 }
175 
initPrograms(SourceCollections & dst) const176 void MeshOnlyTriangleCase::initPrograms(SourceCollections &dst) const
177 {
178     std::ostringstream mesh;
179     mesh << "#version 450\n"
180          << "#extension GL_NV_mesh_shader : enable\n"
181          << "\n"
182          // We will actually output a single triangle and most invocations will do no work.
183          << "layout(local_size_x=32) in;\n"
184          << "layout(triangles) out;\n"
185          << "layout(max_vertices=256, max_primitives=256) out;\n"
186          << "\n"
187          // Unique vertex coordinates.
188          << "layout (set=0, binding=0) uniform CoordsBuffer {\n"
189          << "    vec4 coords[3];\n"
190          << "} cb;\n"
191          // Unique vertex indices.
192          << "layout (set=0, binding=1, std430) readonly buffer IndexBuffer {\n"
193          << "    uint indices[3];\n"
194          << "} ib;\n"
195          << "\n"
196          // Triangle color.
197          << "layout (location=0) out perprimitiveNV vec4 triangleColor[];\n"
198          << "\n"
199          << "void main ()\n"
200          << "{\n"
201          << "    gl_PrimitiveCountNV = 1u;\n"
202          << "    triangleColor[0] = vec4(0.0, 0.0, 1.0, 1.0);\n"
203          << "\n"
204          << "    const uint vertex = gl_LocalInvocationIndex;\n"
205          << "    if (vertex < 3u)\n"
206          << "    {\n"
207          << "        const uint vertexIndex = ib.indices[vertex];\n"
208          << "        gl_PrimitiveIndicesNV[vertex] = vertexIndex;\n"
209          << "        gl_MeshVerticesNV[vertexIndex].gl_Position = cb.coords[vertexIndex];\n"
210          << "    }\n"
211          << "}\n";
212     dst.glslSources.add("mesh") << glu::MeshSource(mesh.str());
213 
214     dst.glslSources.add("frag") << glu::FragmentSource(commonMeshFragShader());
215 }
216 
initPrograms(SourceCollections & dst) const217 void MeshTaskTriangleCase::initPrograms(SourceCollections &dst) const
218 {
219     std::string taskDataDecl = "taskNV TaskData {\n"
220                                "    uint triangleIndex;\n"
221                                "} td;\n";
222 
223     std::ostringstream task;
224     task
225         // Each work group spawns 1 task each (2 in total) and each task will draw 1 triangle.
226         << "#version 450\n"
227         << "#extension GL_NV_mesh_shader : enable\n"
228         << "\n"
229         << "layout(local_size_x=32) in;\n"
230         << "\n"
231         << "out " << taskDataDecl << "\n"
232         << "void main ()\n"
233         << "{\n"
234         << "    if (gl_LocalInvocationIndex == 0u)\n"
235         << "    {\n"
236         << "        gl_TaskCountNV = 1u;\n"
237         << "        td.triangleIndex = gl_WorkGroupID.x;\n"
238         << "    }\n"
239         << "}\n";
240     dst.glslSources.add("task") << glu::TaskSource(task.str());
241 
242     std::ostringstream mesh;
243     mesh << "#version 450\n"
244          << "#extension GL_NV_mesh_shader : enable\n"
245          << "\n"
246          // We will actually output a single triangle and most invocations will do no work.
247          << "layout(local_size_x=32) in;\n"
248          << "layout(triangles) out;\n"
249          << "layout(max_vertices=256, max_primitives=256) out;\n"
250          << "\n"
251          // Unique vertex coordinates.
252          << "layout (set=0, binding=0) uniform CoordsBuffer {\n"
253          << "    vec4 coords[4];\n"
254          << "} cb;\n"
255          // Unique vertex indices.
256          << "layout (set=0, binding=1, std430) readonly buffer IndexBuffer {\n"
257          << "    uint indices[6];\n"
258          << "} ib;\n"
259          << "\n"
260          // Triangle color.
261          << "layout (location=0) out perprimitiveNV vec4 triangleColor[];\n"
262          << "\n"
263          << "in " << taskDataDecl << "\n"
264          << "void main ()\n"
265          << "{\n"
266          << "    if (gl_LocalInvocationIndex == 0u)\n"
267          << "    {\n"
268          << "        gl_PrimitiveCountNV = 1u;\n"
269          << "        triangleColor[0] = vec4(0.0, 0.0, 1.0, 1.0);\n"
270          << "    }\n"
271          << "\n"
272          // Each "active" invocation will copy one vertex.
273          << "    if (gl_LocalInvocationIndex < 3u)\n"
274          << "    {\n"
275          << "\n"
276          << "        const uint triangleVertex = gl_LocalInvocationIndex;\n"
277          << "        const uint coordsIndex    = ib.indices[td.triangleIndex * 3u + triangleVertex];\n"
278          << "\n"
279          // Copy vertex coordinates.
280          << "        gl_MeshVerticesNV[triangleVertex].gl_Position = cb.coords[coordsIndex];\n"
281          // Index renumbering: final indices will always be 0, 1, 2.
282          << "        gl_PrimitiveIndicesNV[triangleVertex] = triangleVertex;\n"
283          << "    }\n"
284          << "}\n";
285     dst.glslSources.add("mesh") << glu::MeshSource(mesh.str());
286 
287     dst.glslSources.add("frag") << glu::FragmentSource(commonMeshFragShader());
288 }
289 
initPrograms(SourceCollections & dst) const290 void TaskOnlyTriangleCase::initPrograms(SourceCollections &dst) const
291 {
292     // The task shader does not spawn any mesh shader invocations.
293     std::ostringstream task;
294     task << "#version 450\n"
295          << "#extension GL_NV_mesh_shader : enable\n"
296          << "\n"
297          << "layout(local_size_x=1) in;\n"
298          << "\n"
299          << "void main ()\n"
300          << "{\n"
301          << "    gl_TaskCountNV = 0u;\n"
302          << "}\n";
303     dst.glslSources.add("task") << glu::TaskSource(task.str());
304 
305     // Same shader as the mesh only case, but it should not be launched.
306     std::ostringstream mesh;
307     mesh << "#version 450\n"
308          << "#extension GL_NV_mesh_shader : enable\n"
309          << "\n"
310          << "layout(local_size_x=32) in;\n"
311          << "layout(triangles) out;\n"
312          << "layout(max_vertices=256, max_primitives=256) out;\n"
313          << "\n"
314          << "layout (set=0, binding=0) uniform CoordsBuffer {\n"
315          << "    vec4 coords[3];\n"
316          << "} cb;\n"
317          << "layout (set=0, binding=1, std430) readonly buffer IndexBuffer {\n"
318          << "    uint indices[3];\n"
319          << "} ib;\n"
320          << "\n"
321          << "layout (location=0) out perprimitiveNV vec4 triangleColor[];\n"
322          << "\n"
323          << "void main ()\n"
324          << "{\n"
325          << "    gl_PrimitiveCountNV = 1u;\n"
326          << "    triangleColor[0] = vec4(0.0, 0.0, 1.0, 1.0);\n"
327          << "\n"
328          << "    const uint vertex = gl_LocalInvocationIndex;\n"
329          << "    if (vertex < 3u)\n"
330          << "    {\n"
331          << "        const uint vertexIndex = ib.indices[vertex];\n"
332          << "        gl_PrimitiveIndicesNV[vertex] = vertexIndex;\n"
333          << "        gl_MeshVerticesNV[vertexIndex].gl_Position = cb.coords[vertexIndex];\n"
334          << "    }\n"
335          << "}\n";
336     dst.glslSources.add("mesh") << glu::MeshSource(mesh.str());
337 
338     dst.glslSources.add("frag") << glu::FragmentSource(commonMeshFragShader());
339 }
340 
createInstance(Context & context) const341 TestInstance *MeshOnlyTriangleCase::createInstance(Context &context) const
342 {
343     const std::vector<tcu::Vec4> vertexCoords = {
344         tcu::Vec4(-1.0f, -1.0f, 0.0f, 1.0f),
345         tcu::Vec4(-1.0f, 3.0f, 0.0f, 1.0f),
346         tcu::Vec4(3.0f, -1.0f, 0.0f, 1.0f),
347     };
348     const std::vector<uint32_t> vertexIndices = {0u, 1u, 2u};
349     MeshTriangleRendererParams params(std::move(vertexCoords), std::move(vertexIndices), 1u,
350                                       tcu::Vec4(0.0f, 0.0f, 1.0f, 1.0f));
351 
352     return new MeshTriangleRenderer(context, std::move(params));
353 }
354 
createInstance(Context & context) const355 TestInstance *MeshTaskTriangleCase::createInstance(Context &context) const
356 {
357     const std::vector<tcu::Vec4> vertexCoords = {
358         tcu::Vec4(-1.0f, -1.0f, 0.0f, 1.0f),
359         tcu::Vec4(-1.0f, 1.0f, 0.0f, 1.0f),
360         tcu::Vec4(1.0f, -1.0f, 0.0f, 1.0f),
361         tcu::Vec4(1.0f, 1.0f, 0.0f, 1.0f),
362     };
363     const std::vector<uint32_t> vertexIndices = {2u, 0u, 1u, 1u, 3u, 2u};
364     MeshTriangleRendererParams params(std::move(vertexCoords), std::move(vertexIndices), 2u,
365                                       tcu::Vec4(0.0f, 0.0f, 1.0f, 1.0f));
366 
367     return new MeshTriangleRenderer(context, std::move(params));
368 }
369 
createInstance(Context & context) const370 TestInstance *TaskOnlyTriangleCase::createInstance(Context &context) const
371 {
372     const std::vector<tcu::Vec4> vertexCoords = {
373         tcu::Vec4(-1.0f, -1.0f, 0.0f, 1.0f),
374         tcu::Vec4(-1.0f, 3.0f, 0.0f, 1.0f),
375         tcu::Vec4(3.0f, -1.0f, 0.0f, 1.0f),
376     };
377     const std::vector<uint32_t> vertexIndices = {0u, 1u, 2u};
378     // Note we expect the clear color.
379     MeshTriangleRendererParams params(std::move(vertexCoords), std::move(vertexIndices), 1u,
380                                       tcu::Vec4(0.0f, 0.0f, 0.0f, 1.0f));
381 
382     return new MeshTriangleRenderer(context, std::move(params));
383 }
384 
iterate()385 tcu::TestStatus MeshTriangleRenderer::iterate()
386 {
387     const auto &vkd   = m_context.getDeviceInterface();
388     const auto device = m_context.getDevice();
389     auto &alloc       = m_context.getDefaultAllocator();
390     const auto qIndex = m_context.getUniversalQueueFamilyIndex();
391     const auto queue  = m_context.getUniversalQueue();
392 
393     const auto vertexBufferStages = VK_SHADER_STAGE_MESH_BIT_NV;
394     const auto vertexBufferSize   = static_cast<VkDeviceSize>(de::dataSize(m_params.vertexCoords));
395     const auto vertexBufferUsage  = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
396     const auto vertexBufferLoc    = DescriptorSetUpdateBuilder::Location::binding(0u);
397     const auto vertexBufferType   = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
398 
399     const auto indexBufferStages = VK_SHADER_STAGE_MESH_BIT_NV;
400     const auto indexBufferSize   = static_cast<VkDeviceSize>(de::dataSize(m_params.vertexIndices));
401     const auto indexBufferUsage  = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
402     const auto indexBufferLoc    = DescriptorSetUpdateBuilder::Location::binding(1u);
403     const auto indexBufferType   = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
404 
405     // Vertex buffer.
406     const auto vertexBufferInfo = makeBufferCreateInfo(vertexBufferSize, vertexBufferUsage);
407     BufferWithMemory vertexBuffer(vkd, device, alloc, vertexBufferInfo, MemoryRequirement::HostVisible);
408     auto &vertexBufferAlloc   = vertexBuffer.getAllocation();
409     void *vertexBufferDataPtr = vertexBufferAlloc.getHostPtr();
410 
411     deMemcpy(vertexBufferDataPtr, m_params.vertexCoords.data(), static_cast<size_t>(vertexBufferSize));
412     flushAlloc(vkd, device, vertexBufferAlloc);
413 
414     // Index buffer.
415     const auto indexBufferInfo = makeBufferCreateInfo(indexBufferSize, indexBufferUsage);
416     BufferWithMemory indexBuffer(vkd, device, alloc, indexBufferInfo, MemoryRequirement::HostVisible);
417     auto &indexBufferAlloc   = indexBuffer.getAllocation();
418     void *indexBufferDataPtr = indexBufferAlloc.getHostPtr();
419 
420     deMemcpy(indexBufferDataPtr, m_params.vertexIndices.data(), static_cast<size_t>(indexBufferSize));
421     flushAlloc(vkd, device, indexBufferAlloc);
422 
423     // Color buffer.
424     const auto colorBufferFormat = VK_FORMAT_R8G8B8A8_UNORM;
425     const auto colorBufferExtent = makeExtent3D(8u, 8u, 1u);
426     const auto colorBufferUsage  = (VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT);
427 
428     const VkImageCreateInfo colorBufferInfo = {
429         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
430         nullptr,                             // const void* pNext;
431         0u,                                  // VkImageCreateFlags flags;
432         VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
433         colorBufferFormat,                   // VkFormat format;
434         colorBufferExtent,                   // VkExtent3D extent;
435         1u,                                  // uint32_t mipLevels;
436         1u,                                  // uint32_t arrayLayers;
437         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
438         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
439         colorBufferUsage,                    // VkImageUsageFlags usage;
440         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
441         0u,                                  // uint32_t queueFamilyIndexCount;
442         nullptr,                             // const uint32_t* pQueueFamilyIndices;
443         VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
444     };
445     ImageWithMemory colorBuffer(vkd, device, alloc, colorBufferInfo, MemoryRequirement::Any);
446 
447     const auto colorSRR = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
448     const auto colorBufferView =
449         makeImageView(vkd, device, colorBuffer.get(), VK_IMAGE_VIEW_TYPE_2D, colorBufferFormat, colorSRR);
450 
451     // Render pass.
452     const auto renderPass = makeRenderPass(vkd, device, colorBufferFormat);
453 
454     // Framebuffer.
455     const auto framebuffer = makeFramebuffer(vkd, device, renderPass.get(), colorBufferView.get(),
456                                              colorBufferExtent.width, colorBufferExtent.height);
457 
458     // Set layout.
459     DescriptorSetLayoutBuilder layoutBuilder;
460     layoutBuilder.addSingleBinding(vertexBufferType, vertexBufferStages);
461     layoutBuilder.addSingleBinding(indexBufferType, indexBufferStages);
462     const auto setLayout = layoutBuilder.build(vkd, device);
463 
464     // Descriptor pool.
465     DescriptorPoolBuilder poolBuilder;
466     poolBuilder.addType(vertexBufferType);
467     poolBuilder.addType(indexBufferType);
468     const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
469 
470     // Descriptor set.
471     const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
472 
473     // Update descriptor set.
474     DescriptorSetUpdateBuilder updateBuilder;
475     const auto vertexBufferDescInfo = makeDescriptorBufferInfo(vertexBuffer.get(), 0ull, vertexBufferSize);
476     const auto indexBufferDescInfo  = makeDescriptorBufferInfo(indexBuffer.get(), 0ull, indexBufferSize);
477     updateBuilder.writeSingle(descriptorSet.get(), vertexBufferLoc, vertexBufferType, &vertexBufferDescInfo);
478     updateBuilder.writeSingle(descriptorSet.get(), indexBufferLoc, indexBufferType, &indexBufferDescInfo);
479     updateBuilder.update(vkd, device);
480 
481     // Pipeline layout.
482     const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
483 
484     // Shader modules.
485     Move<VkShaderModule> taskModule;
486     const auto &binaries = m_context.getBinaryCollection();
487 
488     if (binaries.contains("task"))
489         taskModule = createShaderModule(vkd, device, binaries.get("task"), 0u);
490     const auto meshModule = createShaderModule(vkd, device, binaries.get("mesh"), 0u);
491     const auto fragModule = createShaderModule(vkd, device, binaries.get("frag"), 0u);
492 
493     // Graphics pipeline.
494     std::vector<VkViewport> viewports(1u, makeViewport(colorBufferExtent));
495     std::vector<VkRect2D> scissors(1u, makeRect2D(colorBufferExtent));
496     const auto pipeline = makeGraphicsPipeline(vkd, device, pipelineLayout.get(), taskModule.get(), meshModule.get(),
497                                                fragModule.get(), renderPass.get(), viewports, scissors);
498 
499     // Command pool and buffer.
500     const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
501     const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
502     const auto cmdBuffer    = cmdBufferPtr.get();
503 
504     // Output buffer.
505     const auto tcuFormat      = mapVkFormat(colorBufferFormat);
506     const auto outBufferSize  = static_cast<VkDeviceSize>(static_cast<uint32_t>(tcu::getPixelSize(tcuFormat)) *
507                                                          colorBufferExtent.width * colorBufferExtent.height);
508     const auto outBufferUsage = VK_BUFFER_USAGE_TRANSFER_DST_BIT;
509     const auto outBufferInfo  = makeBufferCreateInfo(outBufferSize, outBufferUsage);
510     BufferWithMemory outBuffer(vkd, device, alloc, outBufferInfo, MemoryRequirement::HostVisible);
511     auto &outBufferAlloc = outBuffer.getAllocation();
512     void *outBufferData  = outBufferAlloc.getHostPtr();
513 
514     // Draw triangle.
515     beginCommandBuffer(vkd, cmdBuffer);
516     beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0),
517                     tcu::Vec4(0.0f, 0.0f, 0.0f, 1.0f) /*clear color*/);
518     vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout.get(), 0u, 1u,
519                               &descriptorSet.get(), 0u, nullptr);
520     vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline.get());
521     vkd.cmdDrawMeshTasksNV(cmdBuffer, m_params.taskCount, 0u);
522     endRenderPass(vkd, cmdBuffer);
523 
524     // Copy color buffer to output buffer.
525     const tcu::IVec3 imageDim(static_cast<int>(colorBufferExtent.width), static_cast<int>(colorBufferExtent.height),
526                               static_cast<int>(colorBufferExtent.depth));
527     const tcu::IVec2 imageSize(imageDim.x(), imageDim.y());
528 
529     copyImageToBuffer(vkd, cmdBuffer, colorBuffer.get(), outBuffer.get(), imageSize);
530     endCommandBuffer(vkd, cmdBuffer);
531     submitCommandsAndWait(vkd, device, queue, cmdBuffer);
532 
533     // Invalidate alloc.
534     invalidateAlloc(vkd, device, outBufferAlloc);
535     tcu::ConstPixelBufferAccess outPixels(tcuFormat, imageDim, outBufferData);
536 
537     auto &log = m_context.getTestContext().getLog();
538     const tcu::Vec4 threshold(0.0f); // The color can be represented exactly.
539 
540     if (!tcu::floatThresholdCompare(log, "Result", "", m_params.expectedColor, outPixels, threshold,
541                                     tcu::COMPARE_LOG_ON_ERROR))
542         return tcu::TestStatus::fail("Failed; check log for details");
543 
544     return tcu::TestStatus::pass("Pass");
545 }
546 
gradientImageExtent()547 VkExtent3D gradientImageExtent()
548 {
549     return makeExtent3D(256u, 256u, 1u);
550 }
551 
checkMeshSupport(Context & context,tcu::Maybe<FragmentSize> fragmentSize)552 void checkMeshSupport(Context &context, tcu::Maybe<FragmentSize> fragmentSize)
553 {
554     DE_UNREF(fragmentSize);
555     checkTaskMeshShaderSupportNV(context, false, true);
556 }
557 
initGradientPrograms(vk::SourceCollections & programCollection,tcu::Maybe<FragmentSize> fragmentSize)558 void initGradientPrograms(vk::SourceCollections &programCollection, tcu::Maybe<FragmentSize> fragmentSize)
559 {
560     const auto extent = gradientImageExtent();
561 
562     std::ostringstream frag;
563     frag << "#version 450\n"
564          << "\n"
565          << "layout (location=0) in  vec4 inColor;\n"
566          << "layout (location=0) out vec4 outColor;\n"
567          << "\n"
568          << "void main ()\n"
569          << "{\n"
570          << "    outColor = inColor;\n"
571          << "}\n";
572     programCollection.glslSources.add("frag") << glu::FragmentSource(frag.str());
573 
574     const auto useFragmentSize = static_cast<bool>(fragmentSize);
575 
576     if (!useFragmentSize)
577     {
578         std::ostringstream mesh;
579         mesh << "#version 450\n"
580              << "#extension GL_NV_mesh_shader : enable\n"
581              << "\n"
582              << "layout(local_size_x=4) in;\n"
583              << "layout(triangles) out;\n"
584              << "layout(max_vertices=256, max_primitives=256) out;\n"
585              << "\n"
586              << "layout (location=0) out vec4 outColor[];\n"
587              << "\n"
588              << "void main ()\n"
589              << "{\n"
590              << "    gl_PrimitiveCountNV = 2u;\n"
591              << "\n"
592              << "    const uint vertex    = gl_LocalInvocationIndex;\n"
593              << "    const uint primitive = gl_LocalInvocationIndex;\n"
594              << "\n"
595              << "    const vec4 topLeft      = vec4(-1.0, -1.0, 0.0, 1.0);\n"
596              << "    const vec4 botLeft      = vec4(-1.0,  1.0, 0.0, 1.0);\n"
597              << "    const vec4 topRight     = vec4( 1.0, -1.0, 0.0, 1.0);\n"
598              << "    const vec4 botRight     = vec4( 1.0,  1.0, 0.0, 1.0);\n"
599              << "    const vec4 positions[4] = vec4[](topLeft, botLeft, topRight, botRight);\n"
600              << "\n"
601              // Green changes according to the width.
602              // Blue changes according to the height.
603              // Value 0 at the center of the first pixel and value 1 at the center of the last pixel.
604              << "    const float width      = " << extent.width << ";\n"
605              << "    const float height     = " << extent.height << ";\n"
606              << "    const float halfWidth  = (1.0 / (width - 1.0)) / 2.0;\n"
607              << "    const float halfHeight = (1.0 / (height - 1.0)) / 2.0;\n"
608              << "    const float minGreen   = -halfWidth;\n"
609              << "    const float maxGreen   = 1.0+halfWidth;\n"
610              << "    const float minBlue    = -halfHeight;\n"
611              << "    const float maxBlue    = 1.0+halfHeight;\n"
612              << "    const vec4  colors[4]  = vec4[](\n"
613              << "        vec4(0, minGreen, minBlue, 1.0),\n"
614              << "        vec4(0, minGreen, maxBlue, 1.0),\n"
615              << "        vec4(0, maxGreen, minBlue, 1.0),\n"
616              << "        vec4(0, maxGreen, maxBlue, 1.0)\n"
617              << "    );\n"
618              << "\n"
619              << "    const uint indices[6] = uint[](0, 1, 2, 1, 3, 2);\n"
620              << "\n"
621              << "    if (vertex < 4u)\n"
622              << "    {\n"
623              << "        gl_MeshVerticesNV[vertex].gl_Position = positions[vertex];\n"
624              << "        outColor[vertex] = colors[vertex];\n"
625              << "    }\n"
626              << "    if (primitive < 2u)\n"
627              << "    {\n"
628              << "        for (uint i = 0; i < 3; ++i) {\n"
629              << "            const uint arrayPos = 3u * primitive + i;\n"
630              << "            gl_PrimitiveIndicesNV[arrayPos] = indices[arrayPos];\n"
631              << "        }\n"
632              << "    }\n"
633              << "}\n";
634         ;
635         programCollection.glslSources.add("mesh") << glu::MeshSource(mesh.str());
636     }
637     else
638     {
639         const int shadingRateVal = getSPVShadingRateValue(fragmentSize.get());
640         DE_ASSERT(shadingRateVal != 0);
641 
642         // The following shader is largely equivalent to the GLSL below if it was accepted by glslang.
643 #if 0
644 #version 450
645 #extension GL_NV_mesh_shader : enable
646 
647         layout(local_size_x=4) in;
648         layout(triangles) out;
649         layout(max_vertices=256, max_primitives=256) out;
650 
651         layout (location=0) out vec4 outColor[];
652 
653         perprimitiveNV out gl_MeshPerPrimitiveNV {
654             int gl_PrimitiveShadingRateEXT;
655         } gl_MeshPrimitivesNV[];
656 
657         void main ()
658         {
659             gl_PrimitiveCountNV = 2u;
660 
661             const uint vertex    = gl_LocalInvocationIndex;
662             const uint primitive = gl_LocalInvocationIndex;
663 
664             const vec4 topLeft      = vec4(-1.0, -1.0, 0.0, 1.0);
665             const vec4 botLeft      = vec4(-1.0,  1.0, 0.0, 1.0);
666             const vec4 topRight     = vec4( 1.0, -1.0, 0.0, 1.0);
667             const vec4 botRight     = vec4( 1.0,  1.0, 0.0, 1.0);
668             const vec4 positions[4] = vec4[](topLeft, botLeft, topRight, botRight);
669 
670             const float width      = IMAGE_WIDTH;
671             const float height     = IMAGE_HEIGHT;
672             const float halfWidth  = (1.0 / (width - 1.0)) / 2.0;
673             const float halfHeight = (1.0 / (height - 1.0)) / 2.0;
674             const float minGreen   = -halfWidth;
675             const float maxGreen   = 1.0+halfWidth;
676             const float minBlue    = -halfHeight;
677             const float maxBlue    = 1.0+halfHeight;
678             const vec4  colors[4]  = vec4[](
679                 vec4(0, minGreen, minBlue, 1.0),
680                 vec4(0, minGreen, maxBlue, 1.0),
681                 vec4(0, maxGreen, minBlue, 1.0),
682                 vec4(0, maxGreen, maxBlue, 1.0)
683             );
684 
685             const uint indices[6] = uint[](0, 1, 2, 1, 3, 2);
686 
687             if (vertex < 4u)
688             {
689                 gl_MeshVerticesNV[vertex].gl_Position = positions[vertex];
690                 outColor[vertex] = colors[vertex];
691             }
692             if (primitive < 2u)
693             {
694                 gl_MeshPrimitivesNV[primitive].gl_PrimitiveShadingRateEXT = SHADING_RATE;
695                 for (uint i = 0; i < 3; ++i)
696                 {
697                     const uint arrayPos = 3u * primitive + i;
698                     gl_PrimitiveIndicesNV[arrayPos] = indices[arrayPos];
699                 }
700             }
701         }
702 #endif
703 
704 #undef SPV_PRECOMPUTED_CONSTANTS
705         std::ostringstream meshSPV;
706         meshSPV
707 
708             << "; SPIR-V\n"
709             << "; Version: 1.0\n"
710             << "; Generator: Khronos Glslang Reference Front End; 10\n"
711             << "; Bound: 145\n"
712             << "; Schema: 0\n"
713             << "               OpCapability MeshShadingNV\n"
714             << "               OpCapability FragmentShadingRateKHR\n" // Added.
715             << "               OpExtension \"SPV_NV_mesh_shader\"\n"
716             << "               OpExtension \"SPV_KHR_fragment_shading_rate\"\n" // Added.
717             << "          %1 = OpExtInstImport \"GLSL.std.450\"\n"
718             << "               OpMemoryModel Logical GLSL450\n"
719             << "               OpEntryPoint MeshNV %4 \"main\" %8 %13 %74 %93 %106 %129\n"
720             << "               OpExecutionMode %4 LocalSize 4 1 1\n"
721             << "               OpExecutionMode %4 OutputVertices 256\n"
722             << "               OpExecutionMode %4 OutputPrimitivesNV 256\n"
723             << "               OpExecutionMode %4 OutputTrianglesNV\n"
724             << "               OpDecorate %8 BuiltIn PrimitiveCountNV\n"
725             << "               OpDecorate %13 BuiltIn LocalInvocationIndex\n"
726             // These will be actual constants.
727             //<< "               OpDecorate %21 SpecId 0\n"
728             //<< "               OpDecorate %27 SpecId 1\n"
729             << "               OpMemberDecorate %70 0 BuiltIn Position\n"
730             << "               OpMemberDecorate %70 1 BuiltIn PointSize\n"
731             << "               OpMemberDecorate %70 2 BuiltIn ClipDistance\n"
732             << "               OpMemberDecorate %70 3 BuiltIn CullDistance\n"
733             << "               OpMemberDecorate %70 4 PerViewNV\n"
734             << "               OpMemberDecorate %70 4 BuiltIn PositionPerViewNV\n"
735             << "               OpMemberDecorate %70 5 PerViewNV\n"
736             << "               OpMemberDecorate %70 5 BuiltIn ClipDistancePerViewNV\n"
737             << "               OpMemberDecorate %70 6 PerViewNV\n"
738             << "               OpMemberDecorate %70 6 BuiltIn CullDistancePerViewNV\n"
739             << "               OpDecorate %70 Block\n"
740             << "               OpDecorate %93 Location 0\n"
741             << "               OpMemberDecorate %103 0 PerPrimitiveNV\n"
742             << "               OpMemberDecorate %103 0 BuiltIn PrimitiveShadingRateKHR\n" // Replaced PrimitiveID.
743             << "               OpDecorate %103 Block\n"
744             << "               OpDecorate %129 BuiltIn PrimitiveIndicesNV\n"
745             << "               OpDecorate %144 BuiltIn WorkgroupSize\n"
746             << "          %2 = OpTypeVoid\n"
747             << "          %3 = OpTypeFunction %2\n"
748             << "          %6 = OpTypeInt 32 0\n"
749             << "          %7 = OpTypePointer Output %6\n"
750             << "          %8 = OpVariable %7 Output\n"
751             << "          %9 = OpConstant %6 2\n"
752             << "         %10 = OpTypePointer Function %6\n"
753             << "         %12 = OpTypePointer Input %6\n"
754             << "         %13 = OpVariable %12 Input\n"
755             << "         %17 = OpTypeFloat 32\n"
756             << "         %18 = OpTypePointer Function %17\n"
757             << "         %20 = OpConstant %17 1\n"
758             << "         %21 = OpConstant %17 " << extent.width << "\n" // Made constant instead of spec constant.
759             << "         %24 = OpConstant %17 2\n"
760             << "         %27 = OpConstant %17 " << extent.height << "\n" // Made constant instead of spec constant.
761             << "         %43 = OpTypeVector %17 4\n"
762             << "         %44 = OpConstant %6 4\n"
763             << "         %45 = OpTypeArray %43 %44\n"
764             << "         %46 = OpTypePointer Function %45\n"
765             << "         %48 = OpConstant %17 0\n"
766             << "         %63 = OpTypeBool\n"
767             << "         %67 = OpConstant %6 1\n"
768             << "         %68 = OpTypeArray %17 %67\n"
769             << "         %69 = OpTypeArray %68 %44\n"
770             << "         %70 = OpTypeStruct %43 %17 %68 %68 %45 %69 %69\n"
771             << "         %71 = OpConstant %6 256\n"
772             << "         %72 = OpTypeArray %70 %71\n"
773             << "         %73 = OpTypePointer Output %72\n"
774             << "         %74 = OpVariable %73 Output\n"
775             << "         %76 = OpTypeInt 32 1\n"
776             << "         %77 = OpConstant %76 0\n"
777             << "         %78 = OpConstant %17 -1\n"
778             << "         %79 = OpConstantComposite %43 %78 %78 %48 %20\n"
779             << "         %80 = OpConstantComposite %43 %78 %20 %48 %20\n"
780             << "         %81 = OpConstantComposite %43 %20 %78 %48 %20\n"
781             << "         %82 = OpConstantComposite %43 %20 %20 %48 %20\n"
782             << "         %83 = OpConstantComposite %45 %79 %80 %81 %82\n"
783             << "         %86 = OpTypePointer Function %43\n"
784             << "         %89 = OpTypePointer Output %43\n"
785             << "         %91 = OpTypeArray %43 %71\n"
786             << "         %92 = OpTypePointer Output %91\n"
787             << "         %93 = OpVariable %92 Output\n"
788             << "        %103 = OpTypeStruct %76\n"
789             << "        %104 = OpTypeArray %103 %71\n"
790             << "        %105 = OpTypePointer Output %104\n"
791             << "        %106 = OpVariable %105 Output\n"
792             << "        %108 = OpConstant %76 " << shadingRateVal << "\n" // Used mask value here.
793             << "        %109 = OpTypePointer Output %76\n"
794             << "        %112 = OpConstant %6 0\n"
795             << "        %119 = OpConstant %6 3\n"
796             << "        %126 = OpConstant %6 768\n"
797             << "        %127 = OpTypeArray %6 %126\n"
798             << "        %128 = OpTypePointer Output %127\n"
799             << "        %129 = OpVariable %128 Output\n"
800             << "        %131 = OpConstant %6 6\n"
801             << "        %132 = OpTypeArray %6 %131\n"
802             << "        %133 = OpConstantComposite %132 %112 %67 %9 %67 %119 %9\n"
803             << "        %135 = OpTypePointer Function %132\n"
804             << "        %141 = OpConstant %76 1\n"
805             << "        %143 = OpTypeVector %6 3\n"
806             << "        %144 = OpConstantComposite %143 %44 %67 %67\n"
807             << "          %4 = OpFunction %2 None %3\n"
808             << "          %5 = OpLabel\n"
809             << "         %11 = OpVariable %10 Function\n"
810             << "         %15 = OpVariable %10 Function\n"
811             << "         %19 = OpVariable %18 Function\n"
812             << "         %26 = OpVariable %18 Function\n"
813             << "         %31 = OpVariable %18 Function\n"
814             << "         %34 = OpVariable %18 Function\n"
815             << "         %37 = OpVariable %18 Function\n"
816             << "         %40 = OpVariable %18 Function\n"
817             << "         %47 = OpVariable %46 Function\n"
818             << "         %85 = OpVariable %46 Function\n"
819             << "        %111 = OpVariable %10 Function\n"
820             << "        %121 = OpVariable %10 Function\n"
821             << "        %136 = OpVariable %135 Function\n"
822             << "               OpStore %8 %9\n"
823             << "         %14 = OpLoad %6 %13\n"
824             << "               OpStore %11 %14\n"
825             << "         %16 = OpLoad %6 %13\n"
826             << "               OpStore %15 %16\n"
827             << "         %22 = OpFSub %17 %21 %20\n"
828             << "         %23 = OpFDiv %17 %20 %22\n"
829             << "         %25 = OpFDiv %17 %23 %24\n"
830             << "               OpStore %19 %25\n"
831             << "         %28 = OpFSub %17 %27 %20\n"
832             << "         %29 = OpFDiv %17 %20 %28\n"
833             << "         %30 = OpFDiv %17 %29 %24\n"
834             << "               OpStore %26 %30\n"
835             << "         %32 = OpLoad %17 %19\n"
836             << "         %33 = OpFNegate %17 %32\n"
837             << "               OpStore %31 %33\n"
838             << "         %35 = OpLoad %17 %26\n"
839             << "         %36 = OpFNegate %17 %35\n"
840             << "               OpStore %34 %36\n"
841             << "         %38 = OpLoad %17 %19\n"
842             << "         %39 = OpFAdd %17 %20 %38\n"
843             << "               OpStore %37 %39\n"
844             << "         %41 = OpLoad %17 %26\n"
845             << "         %42 = OpFAdd %17 %20 %41\n"
846             << "               OpStore %40 %42\n"
847             << "         %49 = OpLoad %17 %31\n"
848             << "         %50 = OpLoad %17 %34\n"
849             << "         %51 = OpCompositeConstruct %43 %48 %49 %50 %20\n"
850             << "         %52 = OpLoad %17 %31\n"
851             << "         %53 = OpLoad %17 %40\n"
852             << "         %54 = OpCompositeConstruct %43 %48 %52 %53 %20\n"
853             << "         %55 = OpLoad %17 %37\n"
854             << "         %56 = OpLoad %17 %34\n"
855             << "         %57 = OpCompositeConstruct %43 %48 %55 %56 %20\n"
856             << "         %58 = OpLoad %17 %37\n"
857             << "         %59 = OpLoad %17 %40\n"
858             << "         %60 = OpCompositeConstruct %43 %48 %58 %59 %20\n"
859             << "         %61 = OpCompositeConstruct %45 %51 %54 %57 %60\n"
860             << "               OpStore %47 %61\n"
861             << "         %62 = OpLoad %6 %11\n"
862             << "         %64 = OpULessThan %63 %62 %44\n"
863             << "               OpSelectionMerge %66 None\n"
864             << "               OpBranchConditional %64 %65 %66\n"
865             << "         %65 = OpLabel\n"
866             << "         %75 = OpLoad %6 %11\n"
867             << "         %84 = OpLoad %6 %11\n"
868             << "               OpStore %85 %83\n"
869             << "         %87 = OpAccessChain %86 %85 %84\n"
870             << "         %88 = OpLoad %43 %87\n"
871             << "         %90 = OpAccessChain %89 %74 %75 %77\n"
872             << "               OpStore %90 %88\n"
873             << "         %94 = OpLoad %6 %11\n"
874             << "         %95 = OpLoad %6 %11\n"
875             << "         %96 = OpAccessChain %86 %47 %95\n"
876             << "         %97 = OpLoad %43 %96\n"
877             << "         %98 = OpAccessChain %89 %93 %94\n"
878             << "               OpStore %98 %97\n"
879             << "               OpBranch %66\n"
880             << "         %66 = OpLabel\n"
881             << "         %99 = OpLoad %6 %15\n"
882             << "        %100 = OpULessThan %63 %99 %9\n"
883             << "               OpSelectionMerge %102 None\n"
884             << "               OpBranchConditional %100 %101 %102\n"
885             << "        %101 = OpLabel\n"
886             << "        %107 = OpLoad %6 %15\n"
887             << "        %110 = OpAccessChain %109 %106 %107 %77\n"
888             << "               OpStore %110 %108\n"
889             << "               OpStore %111 %112\n"
890             << "               OpBranch %113\n"
891             << "        %113 = OpLabel\n"
892             << "               OpLoopMerge %115 %116 None\n"
893             << "               OpBranch %117\n"
894             << "        %117 = OpLabel\n"
895             << "        %118 = OpLoad %6 %111\n"
896             << "        %120 = OpULessThan %63 %118 %119\n"
897             << "               OpBranchConditional %120 %114 %115\n"
898             << "        %114 = OpLabel\n"
899             << "        %122 = OpLoad %6 %15\n"
900             << "        %123 = OpIMul %6 %119 %122\n"
901             << "        %124 = OpLoad %6 %111\n"
902             << "        %125 = OpIAdd %6 %123 %124\n"
903             << "               OpStore %121 %125\n"
904             << "        %130 = OpLoad %6 %121\n"
905             << "        %134 = OpLoad %6 %121\n"
906             << "               OpStore %136 %133\n"
907             << "        %137 = OpAccessChain %10 %136 %134\n"
908             << "        %138 = OpLoad %6 %137\n"
909             << "        %139 = OpAccessChain %7 %129 %130\n"
910             << "               OpStore %139 %138\n"
911             << "               OpBranch %116\n"
912             << "        %116 = OpLabel\n"
913             << "        %140 = OpLoad %6 %111\n"
914             << "        %142 = OpIAdd %6 %140 %141\n"
915             << "               OpStore %111 %142\n"
916             << "               OpBranch %113\n"
917             << "        %115 = OpLabel\n"
918             << "               OpBranch %102\n"
919             << "        %102 = OpLabel\n"
920             << "               OpReturn\n"
921             << "               OpFunctionEnd\n";
922         programCollection.spirvAsmSources.add("mesh") << meshSPV.str();
923     }
924 }
925 
coordColorFormat(int x,int y,const tcu::Vec4 & color)926 std::string coordColorFormat(int x, int y, const tcu::Vec4 &color)
927 {
928     std::ostringstream msg;
929     msg << "[" << x << ", " << y << "]=(" << color.x() << ", " << color.y() << ", " << color.z() << ", " << color.w()
930         << ")";
931     return msg.str();
932 }
933 
// Renders a fullscreen gradient with the "mesh" and "frag" programs (no task shader, a single
// vkCmdDrawMeshTasksNV task) into a 256x256 R8G8B8A8_UNORM color buffer. It then verifies the result
// against a reference gradient.
//
// fragmentSize: when set, the pipeline gets a VkPipelineFragmentShadingRateStateCreateInfoKHR. Its combiner
//               lets the per-primitive shading rate replace the pipeline's 1x1 rate (the mesh program is
//               presumably the variant that writes PrimitiveShadingRateKHR; confirm in initGradientPrograms).
//               The output is then checked in blocks of that fragment size.
//               When empty, blocks are 1x1, i.e. a plain per-pixel comparison.
// Returns a pass status. On mismatch, it logs the result, reference and error-mask images and fails via TCU_FAIL.
tcu::TestStatus testFullscreenGradient(Context &context, tcu::Maybe<FragmentSize> fragmentSize)
{
    const auto &vkd                = context.getDeviceInterface();
    const auto device              = context.getDevice();
    auto &alloc                    = context.getDefaultAllocator();
    const auto qIndex              = context.getUniversalQueueFamilyIndex();
    const auto queue               = context.getUniversalQueue();
    const auto useFragmentSize     = static_cast<bool>(fragmentSize);
    const auto defaultFragmentSize = FragmentSize::SIZE_1X1;
    // Size of the blocks that must come out uniformly colored: the requested rate, or 1x1 without one.
    const auto rateSize            = getShadingRateSize(useFragmentSize ? fragmentSize.get() : defaultFragmentSize);

    // Color buffer.
    const auto colorBufferFormat = VK_FORMAT_R8G8B8A8_UNORM;
    const auto colorBufferExtent =
        makeExtent3D(256u, 256u, 1u); // Big enough for a detailed gradient, small enough to get unique colors.
    const auto colorBufferUsage = (VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT);

    const VkImageCreateInfo colorBufferInfo = {
        VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
        nullptr,                             // const void* pNext;
        0u,                                  // VkImageCreateFlags flags;
        VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
        colorBufferFormat,                   // VkFormat format;
        colorBufferExtent,                   // VkExtent3D extent;
        1u,                                  // uint32_t mipLevels;
        1u,                                  // uint32_t arrayLayers;
        VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
        VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
        colorBufferUsage,                    // VkImageUsageFlags usage;
        VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
        0u,                                  // uint32_t queueFamilyIndexCount;
        nullptr,                             // const uint32_t* pQueueFamilyIndices;
        VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
    };
    ImageWithMemory colorBuffer(vkd, device, alloc, colorBufferInfo, MemoryRequirement::Any);

    const auto colorSRR = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
    const auto colorBufferView =
        makeImageView(vkd, device, colorBuffer.get(), VK_IMAGE_VIEW_TYPE_2D, colorBufferFormat, colorSRR);

    // Render pass.
    const auto renderPass = makeRenderPass(vkd, device, colorBufferFormat);

    // Framebuffer.
    const auto framebuffer = makeFramebuffer(vkd, device, renderPass.get(), colorBufferView.get(),
                                             colorBufferExtent.width, colorBufferExtent.height);

    // Set layout.
    // No bindings are added: the shaders use no descriptors, so this layout is empty.
    DescriptorSetLayoutBuilder layoutBuilder;
    const auto setLayout = layoutBuilder.build(vkd, device);

    // Pipeline layout.
    const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());

    // Shader modules.
    // taskModule is intentionally left empty, so taskModule.get() passes a null handle to makeGraphicsPipeline.
    // NOTE(review): presumably this makes the helper build a mesh-only pipeline; confirm in vktMeshShaderUtil.
    Move<VkShaderModule> taskModule;
    const auto &binaries = context.getBinaryCollection();

    const auto meshModule = createShaderModule(vkd, device, binaries.get("mesh"), 0u);
    const auto fragModule = createShaderModule(vkd, device, binaries.get("frag"), 0u);

    // Optional fragment shading rate state, chained into the pipeline only when a fragment size was requested.
    using ShadingRateInfoPtr = de::MovePtr<VkPipelineFragmentShadingRateStateCreateInfoKHR>;
    ShadingRateInfoPtr pNext;
    if (useFragmentSize)
    {
        pNext  = ShadingRateInfoPtr(new VkPipelineFragmentShadingRateStateCreateInfoKHR);
        *pNext = initVulkanStructure();

        pNext->fragmentSize = getShadingRateSize(
            FragmentSize::SIZE_1X1); // 1x1 will not be used as the primitive rate in tests with fragment size.
        // Per the Vulkan spec, combinerOps[0] combines the pipeline rate with the primitive rate. REPLACE selects
        // the primitive rate. combinerOps[1] = KEEP ignores any attachment rate.
        pNext->combinerOps[0] = VK_FRAGMENT_SHADING_RATE_COMBINER_OP_REPLACE_KHR;
        pNext->combinerOps[1] = VK_FRAGMENT_SHADING_RATE_COMBINER_OP_KEEP_KHR;
    }

    // Graphics pipeline.
    std::vector<VkViewport> viewports(1u, makeViewport(colorBufferExtent));
    std::vector<VkRect2D> scissors(1u, makeRect2D(colorBufferExtent));
    const auto pipeline = makeGraphicsPipeline(vkd, device, pipelineLayout.get(), taskModule.get(), meshModule.get(),
                                               fragModule.get(), renderPass.get(), viewports, scissors, 0u, nullptr,
                                               nullptr, nullptr, nullptr, nullptr, 0u, pNext.get());

    // Command pool and buffer.
    const auto cmdPool      = makeCommandPool(vkd, device, qIndex);
    const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
    const auto cmdBuffer    = cmdBufferPtr.get();

    // Output buffer.
    // Host-visible buffer large enough to hold the whole color attachment (pixel size * width * height).
    const auto tcuFormat      = mapVkFormat(colorBufferFormat);
    const auto outBufferSize  = static_cast<VkDeviceSize>(static_cast<uint32_t>(tcu::getPixelSize(tcuFormat)) *
                                                         colorBufferExtent.width * colorBufferExtent.height);
    const auto outBufferUsage = VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    const auto outBufferInfo  = makeBufferCreateInfo(outBufferSize, outBufferUsage);
    BufferWithMemory outBuffer(vkd, device, alloc, outBufferInfo, MemoryRequirement::HostVisible);
    auto &outBufferAlloc = outBuffer.getAllocation();
    void *outBufferData  = outBufferAlloc.getHostPtr();

    // Draw triangles.
    // A single mesh task; the clear color is opaque black.
    beginCommandBuffer(vkd, cmdBuffer);
    beginRenderPass(vkd, cmdBuffer, renderPass.get(), framebuffer.get(), scissors.at(0),
                    tcu::Vec4(0.0f, 0.0f, 0.0f, 1.0f) /*clear color*/);
    vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline.get());
    vkd.cmdDrawMeshTasksNV(cmdBuffer, 1u, 0u);
    endRenderPass(vkd, cmdBuffer);

    // Copy color buffer to output buffer.
    const tcu::IVec3 imageDim(static_cast<int>(colorBufferExtent.width), static_cast<int>(colorBufferExtent.height),
                              static_cast<int>(colorBufferExtent.depth));
    const tcu::IVec2 imageSize(imageDim.x(), imageDim.y());

    copyImageToBuffer(vkd, cmdBuffer, colorBuffer.get(), outBuffer.get(), imageSize);
    endCommandBuffer(vkd, cmdBuffer);
    submitCommandsAndWait(vkd, device, queue, cmdBuffer);

    // Invalidate alloc.
    // Makes the device writes visible to the host before reading the pixels through outPixels.
    invalidateAlloc(vkd, device, outBufferAlloc);
    tcu::ConstPixelBufferAccess outPixels(tcuFormat, imageDim, outBufferData);

    // Create reference image.
    // Pixel (x, y) gets color (0, x, y, 255). With the integer setPixel on this 8-bit UNORM format, these are
    // presumably raw channel values (green = x/255, blue = y/255), matching the mesh shader's red-less gradient.
    // This relies on the 256x256 extent so that x and y fit in 8 bits.
    tcu::TextureLevel refLevel(tcuFormat, imageDim.x(), imageDim.y(), imageDim.z());
    tcu::PixelBufferAccess refAccess(refLevel);
    for (int y = 0; y < imageDim.y(); ++y)
        for (int x = 0; x < imageDim.x(); ++x)
        {
            const tcu::IVec4 color(0, x, y, 255);
            refAccess.setPixel(color, x, y);
        }

    // Error mask: starts all green; failing blocks are painted red.
    const tcu::TextureFormat maskFormat(tcu::TextureFormat::RGBA, tcu::TextureFormat::UNORM_INT8);
    tcu::TextureLevel errorMask(maskFormat, imageDim.x(), imageDim.y(), imageDim.z());
    tcu::PixelBufferAccess errorAccess(errorMask);
    const tcu::Vec4 green(0.0f, 1.0f, 0.0f, 1.0f);
    const tcu::Vec4 red(1.0f, 0.0f, 0.0f, 1.0f);
    auto &log = context.getTestContext().getLog();

    // Each block needs to have the same color and be equal to one of the pixel colors of that block in the reference image.
    const auto blockWidth  = static_cast<int>(rateSize.width);
    const auto blockHeight = static_cast<int>(rateSize.height);

    tcu::clear(errorAccess, green);
    bool globalFail = false;

    // The block counts use integer division. Trailing pixels are not checked when the image size is not a
    // multiple of the block size; 256 is a multiple of both 1 and 2.
    for (int y = 0; y < imageDim.y() / blockHeight; ++y)
        for (int x = 0; x < imageDim.x() / blockWidth; ++x)
        {
            bool blockFail = false;
            std::vector<tcu::Vec4> candidates;

            candidates.reserve(rateSize.width * rateSize.height);

            // The block's top-left pixel is the color every other pixel in the block is compared against.
            const auto cornerY     = y * blockHeight;
            const auto cornerX     = x * blockWidth;
            const auto cornerColor = outPixels.getPixel(cornerX, cornerY);

            for (int blockY = 0; blockY < blockHeight; ++blockY)
                for (int blockX = 0; blockX < blockWidth; ++blockX)
                {
                    const auto absY     = cornerY + blockY;
                    const auto absX     = cornerX + blockX;
                    const auto resColor = outPixels.getPixel(absX, absY);

                    // Any reference color inside the block is an acceptable value for the whole block.
                    candidates.push_back(refAccess.getPixel(absX, absY));

                    // Exact comparison: both values are decoded from the same 8-bit UNORM format.
                    // Note: one message is logged per mismatching pixel, which can be verbose.
                    if (cornerColor != resColor)
                    {
                        std::ostringstream msg;
                        msg << "Block not uniform: " << coordColorFormat(cornerX, cornerY, cornerColor) << " vs "
                            << coordColorFormat(absX, absY, resColor);
                        log << tcu::TestLog::Message << msg.str() << tcu::TestLog::EndMessage;

                        blockFail = true;
                    }
                }

            if (!de::contains(begin(candidates), end(candidates), cornerColor))
            {
                std::ostringstream msg;
                msg << "Block color does not match any reference color at [" << cornerX << ", " << cornerY << "]";
                log << tcu::TestLog::Message << msg.str() << tcu::TestLog::EndMessage;
                blockFail = true;
            }

            if (blockFail)
            {
                const auto blockAccess = tcu::getSubregion(errorAccess, cornerX, cornerY, blockWidth, blockHeight);
                tcu::clear(blockAccess, red);
                globalFail = true;
            }
        }

    if (globalFail)
    {
        log << tcu::TestLog::Image("Result", "", outPixels);
        log << tcu::TestLog::Image("Reference", "", refAccess);
        log << tcu::TestLog::Image("ErrorMask", "", errorAccess);

        TCU_FAIL("Color mismatch; check log for more details");
    }

    return tcu::TestStatus::pass("Pass");
}
1134 
1135 } // namespace
1136 
createMeshShaderSmokeTests(tcu::TestContext & testCtx)1137 tcu::TestCaseGroup *createMeshShaderSmokeTests(tcu::TestContext &testCtx)
1138 {
1139     GroupPtr smokeTests(new tcu::TestCaseGroup(testCtx, "smoke"));
1140 
1141     smokeTests->addChild(new MeshOnlyTriangleCase(testCtx, "mesh_shader_triangle"));
1142     smokeTests->addChild(new MeshTaskTriangleCase(testCtx, "mesh_task_shader_triangle"));
1143     smokeTests->addChild(new TaskOnlyTriangleCase(testCtx, "task_only_shader_triangle"));
1144 
1145     addFunctionCaseWithPrograms(smokeTests.get(), "fullscreen_gradient", checkMeshSupport, initGradientPrograms,
1146                                 testFullscreenGradient, tcu::nothing<FragmentSize>());
1147     addFunctionCaseWithPrograms(smokeTests.get(), "fullscreen_gradient_fs2x2", checkMeshSupport, initGradientPrograms,
1148                                 testFullscreenGradient, tcu::just(FragmentSize::SIZE_2X2));
1149     addFunctionCaseWithPrograms(smokeTests.get(), "fullscreen_gradient_fs2x1", checkMeshSupport, initGradientPrograms,
1150                                 testFullscreenGradient, tcu::just(FragmentSize::SIZE_2X1));
1151 
1152     return smokeTests.release();
1153 }
1154 
1155 } // namespace MeshShader
1156 } // namespace vkt
1157