1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2015 The Khronos Group Inc.
6  * Copyright (c) 2023 ARM Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Test cases for VK_KHR_shader_expect_assume.
 *        Verifies that the OpAssumeTrueKHR and OpExpectKHR opcodes work correctly.
24  *//*--------------------------------------------------------------------*/
25 
26 #include "vktShaderExpectAssumeTests.hpp"
27 #include "vktShaderExecutor.hpp"
28 #include "vktTestGroupUtil.hpp"
29 
30 #include "tcuStringTemplate.hpp"
31 
32 #include "vkBuilderUtil.hpp"
33 #include "vkCmdUtil.hpp"
34 #include "vkMemUtil.hpp"
35 #include "vkObjUtil.hpp"
36 #include "vkQueryUtil.hpp"
37 #include "vkRefUtil.hpp"
38 #include "vkTypeUtil.hpp"
39 
40 #include "tcuResultCollector.hpp"
41 
42 #include "deArrayUtil.hpp"
43 #include "deSharedPtr.hpp"
44 #include "deStringUtil.hpp"
45 
46 #include <cassert>
47 #include <string>
48 
49 namespace vkt
50 {
51 namespace shaderexecutor
52 {
53 
54 namespace
55 {
56 
57 using namespace vk;
58 constexpr uint32_t kNumElements           = 32;
59 constexpr VkFormat kColorAttachmentFormat = VK_FORMAT_R32G32_UINT;
60 
// Which of the two instructions under test a case exercises
// (presumably Expect -> OpExpectKHR, Assume -> OpAssumeTrueKHR).
enum class OpType
{
    Expect = 0,
    Assume = 1,
};

// Where the shader obtains the value handed to the tested instruction. The graphics
// pipeline layout / specialization info are chosen from this value.
enum class DataClass
{
    Constant               = 0,
    SpecializationConstant = 1,
    PushConstant           = 2,
    StorageBuffer          = 3,
};

// Component type of the tested value. Bool is written to the input buffer as a
// 32-bit VK_TRUE / VK_FALSE word.
enum class DataType
{
    Bool  = 0,
    Int8  = 1,
    Int16 = 2,
    Int32 = 3,
    Int64 = 4,
};
83 
// Parameters describing one test case. Member order matters: instances are
// presumably built with aggregate (brace) initialization by the case factory.
struct TestParam
{
    OpType opType;             // Which tested instruction the case uses.
    DataClass dataClass;       // Source of the tested value (constant, spec constant, push constant, storage buffer).
    DataType dataType;         // Component type of the tested value.
    uint32_t dataChannelCount; // Components per element; 3-component elements use a stride of 4 in the input buffer.
    VkShaderStageFlagBits shaderType; // Stage under test; VK_SHADER_STAGE_COMPUTE_BIT selects the compute path,
                                      // any other value the vertex/fragment graphics path.
    bool wrongExpectation; // When true the input buffer holds values that differ from the expected ones
                           // (value + 1, or VK_FALSE for Bool).
    std::string testName;  // NOTE(review): presumably the test case name; its use is not visible here.
};
94 
95 class ShaderExpectAssumeTestInstance : public TestInstance
96 {
97 public:
ShaderExpectAssumeTestInstance(Context & context,const TestParam & testParam)98     ShaderExpectAssumeTestInstance(Context &context, const TestParam &testParam)
99         : TestInstance(context)
100         , m_testParam(testParam)
101         , m_vk(m_context.getDeviceInterface())
102     {
103         initialize();
104     }
105 
iterate(void)106     virtual tcu::TestStatus iterate(void)
107     {
108         if (m_testParam.shaderType == VK_SHADER_STAGE_COMPUTE_BIT)
109         {
110             dispatch();
111         }
112         else
113         {
114             render();
115         }
116 
117         const uint32_t *outputData = reinterpret_cast<uint32_t *>(m_outputAlloc->getHostPtr());
118         return validateOutput(outputData);
119     }
120 
121 private:
validateOutput(const uint32_t * outputData)122     tcu::TestStatus validateOutput(const uint32_t *outputData)
123     {
124         for (uint32_t i = 0; i < kNumElements; i++)
125         {
126             // (gl_GlobalInvocationID.x, verification result)
127             if (outputData[i * 2] != i || outputData[i * 2 + 1] != 1)
128             {
129                 return tcu::TestStatus::fail("Result comparison failed");
130             }
131         }
132         return tcu::TestStatus::pass("Pass");
133     }
134 
initialize()135     void initialize()
136     {
137         generateCmdBuffer();
138         if (m_testParam.shaderType == VK_SHADER_STAGE_COMPUTE_BIT)
139         {
140             generateStorageBuffers();
141             generateComputePipeline();
142         }
143         else
144         {
145             generateAttachments();
146             generateVertexBuffer();
147             generateStorageBuffers();
148             generateGraphicsPipeline();
149         }
150     }
151 
generateCmdBuffer()152     void generateCmdBuffer()
153     {
154         const VkDevice device = m_context.getDevice();
155 
156         m_cmdPool   = createCommandPool(m_vk, device, VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT,
157                                         m_context.getUniversalQueueFamilyIndex());
158         m_cmdBuffer = allocateCommandBuffer(m_vk, device, *m_cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
159     }
160 
generateVertexBuffer()161     void generateVertexBuffer()
162     {
163         const VkDevice device           = m_context.getDevice();
164         const DeviceInterface &vk       = m_context.getDeviceInterface();
165         const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
166         Allocator &memAlloc             = m_context.getDefaultAllocator();
167         std::vector<tcu::Vec2> vbo;
168         // _____
169         // |  /
170         // | /
171         // |/
172         vbo.emplace_back(tcu::Vec2(-1, -1));
173         vbo.emplace_back(tcu::Vec2(1, 1));
174         vbo.emplace_back(tcu::Vec2(-1, 1));
175         //   /|
176         //  / |
177         // /__|
178         vbo.emplace_back(tcu::Vec2(-1, -1));
179         vbo.emplace_back(tcu::Vec2(1, -1));
180         vbo.emplace_back(tcu::Vec2(1, 1));
181 
182         const size_t dataSize               = vbo.size() * sizeof(tcu::Vec2);
183         const VkBufferCreateInfo bufferInfo = {
184             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // VkStructureType sType;
185             nullptr,                              // const void* pNext;
186             0u,                                   // VkBufferCreateFlags flags;
187             dataSize,                             // VkDeviceSize size;
188             VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,    // VkBufferUsageFlags usage;
189             VK_SHARING_MODE_EXCLUSIVE,            // VkSharingMode sharingMode;
190             1u,                                   // uint32_t queueFamilyCount;
191             &queueFamilyIndex                     // const uint32_t* pQueueFamilyIndices;
192         };
193         m_vertexBuffer = createBuffer(vk, device, &bufferInfo);
194         m_vertexAlloc =
195             memAlloc.allocate(getBufferMemoryRequirements(vk, device, *m_vertexBuffer), MemoryRequirement::HostVisible);
196 
197         void *vertexData = m_vertexAlloc->getHostPtr();
198 
199         VK_CHECK(vk.bindBufferMemory(device, *m_vertexBuffer, m_vertexAlloc->getMemory(), m_vertexAlloc->getOffset()));
200 
201         /* Load vertices into vertex buffer */
202         deMemcpy(vertexData, vbo.data(), dataSize);
203         flushAlloc(vk, device, *m_vertexAlloc);
204     }
205 
generateAttachments()206     void generateAttachments()
207     {
208         const VkDevice device = m_context.getDevice();
209         Allocator &allocator  = m_context.getDefaultAllocator();
210 
211         const VkImageUsageFlags imageUsage = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT;
212 
213         // Color Attachment
214         const VkImageCreateInfo imageInfo = {
215             VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
216             nullptr,                             // const void* pNext;
217             (VkImageCreateFlags)0,               // VkImageCreateFlags flags;
218             VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
219             kColorAttachmentFormat,              // VkFormat format;
220             makeExtent3D(kNumElements, 1, 1),    // VkExtent3D extent;
221             1u,                                  // uint32_t mipLevels;
222             1u,                                  // uint32_t arrayLayers;
223             VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
224             VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
225             imageUsage,                          // VkImageUsageFlags usage;
226             VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
227             0u,                                  // uint32_t queueFamilyIndexCount;
228             nullptr,                             // const uint32_t* pQueueFamilyIndices;
229             VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
230         };
231 
232         const VkImageSubresourceRange imageSubresource =
233             makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
234 
235         m_imageColor      = makeImage(m_vk, device, imageInfo);
236         m_imageColorAlloc = bindImage(m_vk, device, allocator, *m_imageColor, MemoryRequirement::Any);
237         m_imageColorView =
238             makeImageView(m_vk, device, *m_imageColor, VK_IMAGE_VIEW_TYPE_2D, kColorAttachmentFormat, imageSubresource);
239     }
240 
generateGraphicsPipeline()241     void generateGraphicsPipeline()
242     {
243         const VkDevice device = m_context.getDevice();
244         std::vector<VkDescriptorSetLayoutBinding> bindings;
245 
246         if (m_testParam.dataClass == DataClass::StorageBuffer)
247         {
248             VkDescriptorSetLayoutCreateFlags layoutCreateFlags = 0;
249 
250             bindings.emplace_back(VkDescriptorSetLayoutBinding{
251                 0,                                                       // binding
252                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,                       // descriptorType
253                 1,                                                       // descriptorCount
254                 static_cast<VkShaderStageFlags>(m_testParam.shaderType), // stageFlags
255                 nullptr,                                                 // pImmutableSamplers
256             });                                                          // input binding
257 
258             // Create a layout and allocate a descriptor set for it.
259             const VkDescriptorSetLayoutCreateInfo setLayoutCreateInfo = {
260                 vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
261                 nullptr,                                                 // pNext
262                 layoutCreateFlags,                                       // flags
263                 static_cast<uint32_t>(bindings.size()),                  // bindingCount
264                 bindings.data()                                          // pBindings
265             };
266 
267             m_descriptorSetLayout = vk::createDescriptorSetLayout(m_vk, device, &setLayoutCreateInfo);
268             m_pipelineLayout      = makePipelineLayout(m_vk, device, 1, &m_descriptorSetLayout.get(), 0, nullptr);
269         }
270         else if (m_testParam.dataClass == DataClass::PushConstant)
271         {
272             VkPushConstantRange pushConstant{static_cast<VkShaderStageFlags>(m_testParam.shaderType), 0,
273                                              sizeof(VkBool32)};
274             m_pipelineLayout = makePipelineLayout(m_vk, device, 0, nullptr, 1, &pushConstant);
275         }
276         else
277         {
278             m_pipelineLayout = makePipelineLayout(m_vk, device, 0, nullptr, 0, nullptr);
279         }
280 
281         Move<VkShaderModule> vertexModule =
282             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("vert"), 0u);
283         Move<VkShaderModule> fragmentModule =
284             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("frag"), 0u);
285 
286         const VkVertexInputBindingDescription vertexInputBindingDescription = {
287             0,                           // uint32_t binding;
288             sizeof(tcu::Vec2),           // uint32_t strideInBytes;
289             VK_VERTEX_INPUT_RATE_VERTEX, // VkVertexInputStepRate stepRate;
290         };
291 
292         const VkVertexInputAttributeDescription vertexInputAttributeDescription = {
293             0u,                      // uint32_t location;
294             0u,                      // uint32_t binding;
295             VK_FORMAT_R32G32_SFLOAT, // VkFormat format;
296             0u,                      // uint32_t offsetInBytes;
297         };
298 
299         const VkPipelineVertexInputStateCreateInfo vertexInputStateParams = {
300             VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, // VkStructureType sType;
301             nullptr,                                                   // const void* pNext;
302             0,                                                         // VkPipelineVertexInputStateCreateFlags flags;
303             1u,                                                        // uint32_t bindingCount;
304             &vertexInputBindingDescription,   // const VkVertexInputBindingDescription* pVertexBindingDescriptions;
305             1u,                               // uint32_t attributeCount;
306             &vertexInputAttributeDescription, // const VkVertexInputAttributeDescription* pVertexAttributeDescriptions;
307         };
308 
309         const VkPipelineInputAssemblyStateCreateInfo pipelineInputAssemblyStateInfo = {
310             VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO, // VkStructureType sType;
311             nullptr,                                                     // const void* pNext;
312             (VkPipelineInputAssemblyStateCreateFlags)0, // VkPipelineInputAssemblyStateCreateFlags flags;
313             VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST,        // VkPrimitiveTopology topology;
314             VK_FALSE,                                   // VkBool32 primitiveRestartEnable;
315         };
316 
317         const VkViewport viewport{0, 0, static_cast<float>(kNumElements), 1, 0, 1};
318         const VkRect2D scissor{{0, 0}, {kNumElements, 1}};
319 
320         const VkPipelineViewportStateCreateInfo pipelineViewportStateInfo = {
321             VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, // VkStructureType sType;
322             nullptr,                                               // const void* pNext;
323             (VkPipelineViewportStateCreateFlags)0,                 // VkPipelineViewportStateCreateFlags flags;
324             1u,                                                    // uint32_t viewportCount;
325             &viewport,                                             // const VkViewport* pViewports;
326             1u,                                                    // uint32_t scissorCount;
327             &scissor,                                              // const VkRect2D* pScissors;
328         };
329 
330         const VkPipelineRasterizationStateCreateInfo pipelineRasterizationStateInfo = {
331             VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO, // VkStructureType sType;
332             nullptr,                                                    // const void* pNext;
333             0u,                              // VkPipelineRasterizationStateCreateFlags flags;
334             VK_FALSE,                        // VkBool32 depthClampEnable;
335             VK_FALSE,                        // VkBool32 rasterizerDiscardEnable;
336             VK_POLYGON_MODE_FILL,            // VkPolygonMode polygonMode;
337             VK_CULL_MODE_NONE,               // VkCullModeFlags cullMode;
338             VK_FRONT_FACE_COUNTER_CLOCKWISE, // VkFrontFace frontFace;
339             VK_FALSE,                        // VkBool32 depthBiasEnable;
340             0.0f,                            // float depthBiasConstantFactor;
341             0.0f,                            // float depthBiasClamp;
342             0.0f,                            // float depthBiasSlopeFactor;
343             1.0f,                            // float lineWidth;
344         };
345 
346         const VkPipelineMultisampleStateCreateInfo pipelineMultisampleStateInfo = {
347             VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO, // VkStructureType sType;
348             nullptr,                                                  // const void* pNext;
349             0u,                                                       // VkPipelineMultisampleStateCreateFlags flags;
350             VK_SAMPLE_COUNT_1_BIT,                                    // VkSampleCountFlagBits rasterizationSamples;
351             VK_FALSE,                                                 // VkBool32 sampleShadingEnable;
352             1.0f,                                                     // float minSampleShading;
353             nullptr,                                                  // const VkSampleMask* pSampleMask;
354             VK_FALSE,                                                 // VkBool32 alphaToCoverageEnable;
355             VK_FALSE                                                  // VkBool32 alphaToOneEnable;
356         };
357 
358         std::vector<VkPipelineColorBlendAttachmentState> colorBlendAttachmentState(
359             1,
360             {
361                 false,                                                // VkBool32 blendEnable;
362                 VK_BLEND_FACTOR_ONE,                                  // VkBlend srcBlendColor;
363                 VK_BLEND_FACTOR_ONE,                                  // VkBlend destBlendColor;
364                 VK_BLEND_OP_ADD,                                      // VkBlendOp blendOpColor;
365                 VK_BLEND_FACTOR_ONE,                                  // VkBlend srcBlendAlpha;
366                 VK_BLEND_FACTOR_ONE,                                  // VkBlend destBlendAlpha;
367                 VK_BLEND_OP_ADD,                                      // VkBlendOp blendOpAlpha;
368                 (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT) // VkChannelFlags channelWriteMask;
369             });
370 
371         const VkPipelineColorBlendStateCreateInfo pipelineColorBlendStateInfo = {
372             VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO, // VkStructureType sType;
373             nullptr,                                                  // const void* pNext;
374             /* always needed */
375             0,                                          // VkPipelineColorBlendStateCreateFlags flags;
376             false,                                      // VkBool32 logicOpEnable;
377             VK_LOGIC_OP_COPY,                           // VkLogicOp logicOp;
378             (uint32_t)colorBlendAttachmentState.size(), // uint32_t attachmentCount;
379             colorBlendAttachmentState.data(),           // const VkPipelineColorBlendAttachmentState* pAttachments;
380             {0.0f, 0.0f, 0.0f, 0.0f},                   // float blendConst[4];
381         };
382 
383         VkStencilOpState stencilOpState = {
384             VK_STENCIL_OP_ZERO,               // VkStencilOp failOp;
385             VK_STENCIL_OP_INCREMENT_AND_WRAP, // VkStencilOp passOp;
386             VK_STENCIL_OP_INCREMENT_AND_WRAP, // VkStencilOp depthFailOp;
387             VK_COMPARE_OP_ALWAYS,             // VkCompareOp compareOp;
388             0xff,                             // uint32_t compareMask;
389             0xff,                             // uint32_t writeMask;
390             0,                                // uint32_t reference;
391         };
392 
393         VkPipelineDepthStencilStateCreateInfo pipelineDepthStencilStateInfo = {
394             VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO,
395             // VkStructureType sType;
396             nullptr, // const void* pNext;
397             0,
398             // VkPipelineDepthStencilStateCreateFlags flags;
399             VK_FALSE,             // VkBool32 depthTestEnable;
400             VK_FALSE,             // VkBool32 depthWriteEnable;
401             VK_COMPARE_OP_ALWAYS, // VkCompareOp depthCompareOp;
402             VK_FALSE,             // VkBool32 depthBoundsTestEnable;
403             VK_FALSE,             // VkBool32 stencilTestEnable;
404             stencilOpState,       // VkStencilOpState front;
405             stencilOpState,       // VkStencilOpState back;
406             0.0f,                 // float minDepthBounds;
407             1.0f,                 // float maxDepthBounds;
408         };
409 
410         const VkPipelineRenderingCreateInfoKHR renderingCreateInfo = {
411             VK_STRUCTURE_TYPE_PIPELINE_RENDERING_CREATE_INFO_KHR, // VkStructureType sType;
412             nullptr,                                              // const void* pNext;
413             0u,                                                   // uint32_t viewMask;
414             1,                                                    // uint32_t colorAttachmentCount;
415             &kColorAttachmentFormat,                              // const VkFormat* pColorAttachmentFormats;
416             VK_FORMAT_UNDEFINED,                                  // VkFormat depthAttachmentFormat;
417             VK_FORMAT_UNDEFINED,                                  // VkFormat stencilAttachmentFormat;
418         };
419 
420         VkSpecializationMapEntry specializationMapEntry = {0, 0, sizeof(VkBool32)};
421         VkBool32 specializationData                     = VK_TRUE;
422         VkSpecializationInfo specializationInfo = {1, &specializationMapEntry, sizeof(VkBool32), &specializationData};
423 
424         const VkPipelineShaderStageCreateInfo pShaderStages[] = {
425             {
426                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
427                 nullptr,                                             // const void*  pNext;
428                 (VkPipelineShaderStageCreateFlags)0,                 // VkPipelineShaderStageCreateFlags flags;
429                 VK_SHADER_STAGE_VERTEX_BIT,                          // VkShaderStageFlagBits stage;
430                 *vertexModule,                                       // VkShaderModule module;
431                 "main",                                              // const char* pName;
432                 (m_testParam.dataClass == DataClass::SpecializationConstant &&
433                  m_testParam.shaderType == VK_SHADER_STAGE_VERTEX_BIT) ?
434                     &specializationInfo :
435                     nullptr, // const VkSpecializationInfo* pSpecializationInfo;
436             },
437             {
438                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
439                 nullptr,                                             // const void* pNext;
440                 (VkPipelineShaderStageCreateFlags)0,                 // VkPipelineShaderStageCreateFlags flags;
441                 VK_SHADER_STAGE_FRAGMENT_BIT,                        // VkShaderStageFlagBits stage;
442                 *fragmentModule,                                     // VkShaderModule module;
443                 "main",                                              // const char* pName;
444                 (m_testParam.dataClass == DataClass::SpecializationConstant &&
445                  m_testParam.shaderType == VK_SHADER_STAGE_FRAGMENT_BIT) ?
446                     &specializationInfo :
447                     nullptr, // const VkSpecializationInfo* pSpecializationInfo;
448             },
449         };
450 
451         const VkGraphicsPipelineCreateInfo graphicsPipelineInfo = {
452             VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO, // VkStructureType sType;
453             &renderingCreateInfo,                            // const void* pNext;
454             (VkPipelineCreateFlags)0,                        // VkPipelineCreateFlags flags;
455             2u,                                              // uint32_t stageCount;
456             pShaderStages,                                   // const VkPipelineShaderStageCreateInfo* pStages;
457             &vertexInputStateParams,         // const VkPipelineVertexInputStateCreateInfo* pVertexInputState;
458             &pipelineInputAssemblyStateInfo, // const VkPipelineInputAssemblyStateCreateInfo* pInputAssemblyState;
459             nullptr,                         // const VkPipelineTessellationStateCreateInfo* pTessellationState;
460             &pipelineViewportStateInfo,      // const VkPipelineViewportStateCreateInfo* pViewportState;
461             &pipelineRasterizationStateInfo, // const VkPipelineRasterizationStateCreateInfo* pRasterizationState;
462             &pipelineMultisampleStateInfo,   // const VkPipelineMultisampleStateCreateInfo* pMultisampleState;
463             &pipelineDepthStencilStateInfo,  // const VkPipelineDepthStencilStateCreateInfo* pDepthStencilState;
464             &pipelineColorBlendStateInfo,    // const VkPipelineColorBlendStateCreateInfo* pColorBlendState;
465             nullptr,                         // const VkPipelineDynamicStateCreateInfo* pDynamicState;
466             *m_pipelineLayout,               // VkPipelineLayout layout;
467             VK_NULL_HANDLE,                  // VkRenderPass renderPass;
468             0u,                              // uint32_t subpass;
469             VK_NULL_HANDLE,                  // VkPipeline basePipelineHandle;
470             0,                               // int32_t basePipelineIndex;
471         };
472 
473         m_pipeline = createGraphicsPipeline(m_vk, device, VK_NULL_HANDLE, &graphicsPipelineInfo);
474 
475         // DescriptorSet create/update for input storage buffer
476         if (m_testParam.dataClass == DataClass::StorageBuffer)
477         {
478             // DescriptorPool/DescriptorSet create
479             VkDescriptorPoolCreateFlags poolCreateFlags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
480 
481             vk::DescriptorPoolBuilder poolBuilder;
482             for (uint32_t i = 0; i < static_cast<uint32_t>(bindings.size()); ++i)
483             {
484                 poolBuilder.addType(bindings[i].descriptorType, bindings[i].descriptorCount);
485             }
486             m_descriptorPool = poolBuilder.build(m_vk, device, poolCreateFlags, 1);
487 
488             m_descriptorSet = makeDescriptorSet(m_vk, device, *m_descriptorPool, *m_descriptorSetLayout);
489 
490             // DescriptorSet update
491             VkDescriptorBufferInfo inputBufferInfo;
492             std::vector<VkDescriptorBufferInfo> bufferInfos;
493 
494             inputBufferInfo = makeDescriptorBufferInfo(m_inputBuffer.get(), 0, VK_WHOLE_SIZE);
495             bufferInfos.push_back(inputBufferInfo); // binding 1 is input if needed
496 
497             VkWriteDescriptorSet w = {
498                 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,    // sType
499                 nullptr,                                   // pNext
500                 *m_descriptorSet,                          // dstSet
501                 (uint32_t)0,                               // dstBinding
502                 0,                                         // dstArrayEllement
503                 static_cast<uint32_t>(bufferInfos.size()), // descriptorCount
504                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         // descriptorType
505                 nullptr,                                   // pImageInfo
506                 bufferInfos.data(),                        // pBufferInfo
507                 nullptr,                                   // pTexelBufferView
508             };
509 
510             m_vk.updateDescriptorSets(device, 1, &w, 0, nullptr);
511         }
512     }
513 
generateStorageBuffers()514     void generateStorageBuffers()
515     {
516         // Avoid creating zero-sized buffer/memory
517         const size_t inputBufferSize  = kNumElements * sizeof(uint64_t) * 4; // maximum size, 4 vector of 64bit
518         const size_t outputBufferSize = kNumElements * sizeof(uint32_t) * 2;
519 
520         // Upload data to buffer
521         const VkDevice device           = m_context.getDevice();
522         const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
523         Allocator &memAlloc             = m_context.getDefaultAllocator();
524 
525         const VkBufferCreateInfo inputBufferParams = {
526             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // VkStructureType sType;
527             nullptr,                              // const void* pNext;
528             0u,                                   // VkBufferCreateFlags flags;
529             inputBufferSize,                      // VkDeviceSize size;
530             VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,   // VkBufferUsageFlags usage;
531             VK_SHARING_MODE_EXCLUSIVE,            // VkSharingMode sharingMode;
532             1u,                                   // uint32_t queueFamilyCount;
533             &queueFamilyIndex                     // const uint32_t* pQueueFamilyIndices;
534         };
535 
536         m_inputBuffer   = createBuffer(m_vk, device, &inputBufferParams);
537         m_inputAlloc    = memAlloc.allocate(getBufferMemoryRequirements(m_vk, device, *m_inputBuffer),
538                                             MemoryRequirement::HostVisible);
539         void *inputData = m_inputAlloc->getHostPtr();
540 
541         // element stride of channel count 3 is 4, otherwise same to channel count
542         const uint32_t elementStride = (m_testParam.dataChannelCount != 3) ? m_testParam.dataChannelCount : 4;
543 
544         for (uint32_t i = 0; i < kNumElements; i++)
545         {
546             for (uint32_t channel = 0; channel < m_testParam.dataChannelCount; channel++)
547             {
548                 const uint32_t index = (i * elementStride) + channel;
549                 uint32_t value       = i + channel;
550                 if (m_testParam.wrongExpectation)
551                 {
552                     value += 1; // write wrong value to storage buffer
553                 }
554 
555                 switch (m_testParam.dataType)
556                 {
557                 case DataType::Bool: // std430 layout alignment of machine type(GLfloat)
558                     reinterpret_cast<int32_t *>(inputData)[index] = m_testParam.wrongExpectation ? VK_FALSE : VK_TRUE;
559                     break;
560                 case DataType::Int8:
561                     reinterpret_cast<int8_t *>(inputData)[index] = static_cast<int8_t>(value);
562                     break;
563                 case DataType::Int16:
564                     reinterpret_cast<int16_t *>(inputData)[index] = static_cast<int16_t>(value);
565                     break;
566                 case DataType::Int32:
567                     reinterpret_cast<int32_t *>(inputData)[index] = static_cast<int32_t>(value);
568                     break;
569                 case DataType::Int64:
570                     reinterpret_cast<int64_t *>(inputData)[index] = static_cast<int64_t>(value);
571                     break;
572                 default:
573                     assert(false);
574                 }
575             }
576         }
577 
578         VK_CHECK(m_vk.bindBufferMemory(device, *m_inputBuffer, m_inputAlloc->getMemory(), m_inputAlloc->getOffset()));
579         flushAlloc(m_vk, device, *m_inputAlloc);
580 
581         const VkBufferCreateInfo outputBufferParams = {
582             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,                                  // VkStructureType sType;
583             nullptr,                                                               // const void* pNext;
584             0u,                                                                    // VkBufferCreateFlags flags;
585             outputBufferSize,                                                      // VkDeviceSize size;
586             VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, // VkBufferUsageFlags usage;
587             VK_SHARING_MODE_EXCLUSIVE,                                             // VkSharingMode sharingMode;
588             1u,                                                                    // uint32_t queueFamilyCount;
589             &queueFamilyIndex // const uint32_t* pQueueFamilyIndices;
590         };
591 
592         m_outputBuffer = createBuffer(m_vk, device, &outputBufferParams);
593         m_outputAlloc  = memAlloc.allocate(getBufferMemoryRequirements(m_vk, device, *m_outputBuffer),
594                                            MemoryRequirement::HostVisible);
595 
596         void *outputData = m_outputAlloc->getHostPtr();
597         deMemset(outputData, 0, sizeof(outputBufferSize));
598 
599         VK_CHECK(
600             m_vk.bindBufferMemory(device, *m_outputBuffer, m_outputAlloc->getMemory(), m_outputAlloc->getOffset()));
601         flushAlloc(m_vk, device, *m_outputAlloc);
602     }
603 
generateComputePipeline()604     void generateComputePipeline()
605     {
606         const VkDevice device = m_context.getDevice();
607 
608         const Unique<VkShaderModule> cs(
609             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("comp"), 0));
610 
611         VkDescriptorSetLayoutCreateFlags layoutCreateFlags = 0;
612 
613         std::vector<VkDescriptorSetLayoutBinding> bindings;
614         bindings.emplace_back(VkDescriptorSetLayoutBinding{
615             0,                                 // binding
616             VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, // descriptorType
617             1,                                 // descriptorCount
618             VK_SHADER_STAGE_COMPUTE_BIT,       // stageFlags
619             nullptr,                           // pImmutableSamplers
620         });                                    // output binding
621 
622         if (m_testParam.dataClass == DataClass::StorageBuffer)
623         {
624             bindings.emplace_back(VkDescriptorSetLayoutBinding{
625                 1,                                 // binding
626                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, // descriptorType
627                 1,                                 // descriptorCount
628                 VK_SHADER_STAGE_COMPUTE_BIT,       // stageFlags
629                 nullptr,                           // pImmutableSamplers
630             });                                    // input binding
631         }
632 
633         // Create a layout and allocate a descriptor set for it.
634         const VkDescriptorSetLayoutCreateInfo setLayoutCreateInfo = {
635             vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
636             nullptr,                                                 // pNext
637             layoutCreateFlags,                                       // flags
638             static_cast<uint32_t>(bindings.size()),                  // bindingCount
639             bindings.data()                                          // pBindings
640         };
641 
642         m_descriptorSetLayout = vk::createDescriptorSetLayout(m_vk, device, &setLayoutCreateInfo);
643 
644         VkSpecializationMapEntry specializationMapEntry = {0, 0, sizeof(VkBool32)};
645         VkBool32 specializationData                     = VK_TRUE;
646         VkSpecializationInfo specializationInfo = {1, &specializationMapEntry, sizeof(VkBool32), &specializationData};
647         const VkPipelineShaderStageCreateInfo csShaderCreateInfo = {
648             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
649             nullptr,
650             (VkPipelineShaderStageCreateFlags)0,
651             VK_SHADER_STAGE_COMPUTE_BIT, // stage
652             *cs,                         // shader
653             "main",
654             (m_testParam.dataClass == DataClass::SpecializationConstant) ? &specializationInfo :
655                                                                            nullptr, // pSpecializationInfo
656         };
657 
658         VkPushConstantRange pushConstantRange = {VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(VkBool32)};
659         m_pipelineLayout = makePipelineLayout(m_vk, device, 1, &m_descriptorSetLayout.get(), 1, &pushConstantRange);
660 
661         const VkComputePipelineCreateInfo pipelineCreateInfo = {
662             VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
663             nullptr,
664             0u,                 // flags
665             csShaderCreateInfo, // cs
666             *m_pipelineLayout,  // layout
667             (vk::VkPipeline)0,  // basePipelineHandle
668             0u,                 // basePipelineIndex
669         };
670 
671         m_pipeline = createComputePipeline(m_vk, device, VK_NULL_HANDLE, &pipelineCreateInfo, nullptr);
672 
673         // DescriptorSet create for input/output storage buffer
674         VkDescriptorPoolCreateFlags poolCreateFlags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
675 
676         vk::DescriptorPoolBuilder poolBuilder;
677         for (uint32_t i = 0; i < static_cast<uint32_t>(bindings.size()); ++i)
678         {
679             poolBuilder.addType(bindings[i].descriptorType, bindings[i].descriptorCount);
680         }
681         m_descriptorPool = poolBuilder.build(m_vk, device, poolCreateFlags, 1);
682 
683         m_descriptorSet = makeDescriptorSet(m_vk, device, *m_descriptorPool, *m_descriptorSetLayout);
684 
685         // DescriptorSet update
686         VkDescriptorBufferInfo outputBufferInfo;
687         VkDescriptorBufferInfo inputBufferInfo;
688         std::vector<VkDescriptorBufferInfo> bufferInfos;
689 
690         outputBufferInfo = makeDescriptorBufferInfo(m_outputBuffer.get(), 0, VK_WHOLE_SIZE);
691         bufferInfos.push_back(outputBufferInfo); // binding 0 is output
692 
693         if (m_testParam.dataClass == DataClass::StorageBuffer)
694         {
695             inputBufferInfo = makeDescriptorBufferInfo(m_inputBuffer.get(), 0, VK_WHOLE_SIZE);
696             bufferInfos.push_back(inputBufferInfo); // binding 1 is input if needed
697         }
698 
699         VkWriteDescriptorSet w = {
700             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,    // sType
701             nullptr,                                   // pNext
702             *m_descriptorSet,                          // dstSet
703             (uint32_t)0,                               // dstBinding
704             0,                                         // dstArrayEllement
705             static_cast<uint32_t>(bufferInfos.size()), // descriptorCount
706             VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         // descriptorType
707             nullptr,                                   // pImageInfo
708             bufferInfos.data(),                        // pBufferInfo
709             nullptr,                                   // pTexelBufferView
710         };
711 
712         m_vk.updateDescriptorSets(device, 1, &w, 0, nullptr);
713     }
714 
dispatch()715     void dispatch()
716     {
717         const VkDevice device = m_context.getDevice();
718         const VkQueue queue   = m_context.getUniversalQueue();
719 
720         beginCommandBuffer(m_vk, *m_cmdBuffer);
721         m_vk.cmdBindPipeline(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *m_pipeline);
722         m_vk.cmdBindDescriptorSets(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *m_pipelineLayout, 0u, 1,
723                                    &m_descriptorSet.get(), 0u, nullptr);
724 
725         if (m_testParam.dataClass == DataClass::PushConstant)
726         {
727             VkBool32 pcValue = VK_TRUE;
728             m_vk.cmdPushConstants(*m_cmdBuffer, *m_pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(VkBool32),
729                                   &pcValue);
730         }
731         m_vk.cmdDispatch(*m_cmdBuffer, 1, 1, 1);
732 
733         const VkMemoryBarrier barrier = {
734             VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
735             nullptr,                          // pNext
736             VK_ACCESS_SHADER_WRITE_BIT,       // srcAccessMask
737             VK_ACCESS_HOST_READ_BIT,          // dstAccessMask
738         };
739         m_vk.cmdPipelineBarrier(*m_cmdBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
740                                 (VkDependencyFlags)0, 1, &barrier, 0, nullptr, 0, nullptr);
741 
742         VK_CHECK(m_vk.endCommandBuffer(*m_cmdBuffer));
743         submitCommandsAndWait(m_vk, device, queue, m_cmdBuffer.get());
744         flushMappedMemoryRange(m_vk, device, m_outputAlloc->getMemory(), 0, VK_WHOLE_SIZE);
745     }
746 
render()747     void render()
748     {
749         const VkDevice device = m_context.getDevice();
750         const VkQueue queue   = m_context.getUniversalQueue();
751 
752         beginCommandBuffer(m_vk, *m_cmdBuffer);
753 
754         // begin render pass
755         const VkClearValue clearValue = {}; // { 0, 0, 0, 0 }
756         const VkRect2D renderArea     = {{0, 0}, {kNumElements, 1}};
757 
758         const VkRenderingAttachmentInfoKHR renderingAttInfo = {
759             VK_STRUCTURE_TYPE_RENDERING_ATTACHMENT_INFO_KHR, // VkStructureType sType;
760             nullptr,                                         // const void* pNext;
761             *m_imageColorView,                               // VkImageView imageView;
762             VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,        // VkImageLayout imageLayout;
763             VK_RESOLVE_MODE_NONE,                            // VkResolveModeFlagBits resolveMode;
764             VK_NULL_HANDLE,                                  // VkImageView resolveImageView;
765             VK_IMAGE_LAYOUT_UNDEFINED,                       // VkImageLayout resolveImageLayout;
766             VK_ATTACHMENT_LOAD_OP_CLEAR,                     // VkAttachmentLoadOp loadOp;
767             VK_ATTACHMENT_STORE_OP_STORE,                    // VkAttachmentStoreOp storeOp;
768             clearValue,                                      // VkClearValue clearValue;
769         };
770 
771         const VkRenderingInfoKHR renderingInfo = {
772             VK_STRUCTURE_TYPE_RENDERING_INFO_KHR, // VkStructureType sType;
773             nullptr,                              // const void* pNext;
774             0,                                    // VkRenderingFlagsKHR flags;
775             renderArea,                           // VkRect2D renderArea;
776             1u,                                   // uint32_t layerCount;
777             0u,                                   // uint32_t viewMask;
778             1,                                    // uint32_t colorAttachmentCount;
779             &renderingAttInfo,                    // const VkRenderingAttachmentInfoKHR* pColorAttachments;
780             nullptr,                              // const VkRenderingAttachmentInfoKHR* pDepthAttachment;
781             nullptr                               // const VkRenderingAttachmentInfoKHR* pStencilAttachment;
782         };
783 
784         auto transition2DImage = [](const vk::DeviceInterface &vk, vk::VkCommandBuffer cmdBuffer, vk::VkImage image,
785                                     vk::VkImageAspectFlags aspectMask, vk::VkImageLayout oldLayout,
786                                     vk::VkImageLayout newLayout, vk::VkAccessFlags srcAccessMask,
787                                     vk::VkAccessFlags dstAccessMask, vk::VkPipelineStageFlags srcStageMask,
788                                     vk::VkPipelineStageFlags dstStageMask)
789         {
790             vk::VkImageMemoryBarrier barrier;
791             barrier.sType                           = vk::VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
792             barrier.pNext                           = nullptr;
793             barrier.srcAccessMask                   = srcAccessMask;
794             barrier.dstAccessMask                   = dstAccessMask;
795             barrier.oldLayout                       = oldLayout;
796             barrier.newLayout                       = newLayout;
797             barrier.srcQueueFamilyIndex             = VK_QUEUE_FAMILY_IGNORED;
798             barrier.dstQueueFamilyIndex             = VK_QUEUE_FAMILY_IGNORED;
799             barrier.image                           = image;
800             barrier.subresourceRange.aspectMask     = aspectMask;
801             barrier.subresourceRange.baseMipLevel   = 0;
802             barrier.subresourceRange.levelCount     = 1;
803             barrier.subresourceRange.baseArrayLayer = 0;
804             barrier.subresourceRange.layerCount     = 1;
805 
806             vk.cmdPipelineBarrier(cmdBuffer, srcStageMask, dstStageMask, (vk::VkDependencyFlags)0, 0,
807                                   (const vk::VkMemoryBarrier *)nullptr, 0, (const vk::VkBufferMemoryBarrier *)nullptr,
808                                   1, &barrier);
809         };
810 
811         transition2DImage(m_vk, *m_cmdBuffer, *m_imageColor, VK_IMAGE_ASPECT_COLOR_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
812                           VK_IMAGE_LAYOUT_GENERAL, 0, VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT,
813                           VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT);
814 
815         m_vk.cmdBeginRendering(*m_cmdBuffer, &renderingInfo);
816 
817         // vertex input setup
818         // pipeline setup
819         m_vk.cmdBindPipeline(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *m_pipeline);
820 
821         const uint32_t vertexCount = 6;
822         const VkDeviceSize pOffset = 0;
823         assert(vertexCount <= kNumElements);
824         if (m_testParam.dataClass == DataClass::PushConstant)
825         {
826             const VkBool32 pcValue = VK_TRUE;
827             m_vk.cmdPushConstants(*m_cmdBuffer, *m_pipelineLayout,
828                                   static_cast<VkShaderStageFlags>(m_testParam.shaderType), 0, sizeof(VkBool32),
829                                   &pcValue);
830         }
831         else if (m_testParam.dataClass == DataClass::StorageBuffer)
832         {
833             m_vk.cmdBindDescriptorSets(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *m_pipelineLayout, 0u, 1,
834                                        &m_descriptorSet.get(), 0u, nullptr);
835         }
836         m_vk.cmdBindVertexBuffers(*m_cmdBuffer, 0, 1, &m_vertexBuffer.get(), &pOffset);
837 
838         m_vk.cmdDraw(*m_cmdBuffer, vertexCount, 1, 0, 0u);
839 
840         m_vk.cmdEndRendering(*m_cmdBuffer);
841 
842         VkMemoryBarrier memBarrier = {
843             VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
844             nullptr,                          // pNext
845             0u,                               // srcAccessMask
846             0u,                               // dstAccessMask
847         };
848         memBarrier.srcAccessMask = VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT;
849         memBarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT;
850         m_vk.cmdPipelineBarrier(*m_cmdBuffer, VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT,
851                                 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &memBarrier, 0, nullptr, 0, nullptr);
852 
853         // copy color image to output buffer
854         const VkImageSubresourceLayers imageSubresource = {VK_IMAGE_ASPECT_COLOR_BIT, 0, 0, 1};
855         const VkOffset3D imageOffset                    = {};
856         const VkExtent3D imageExtent                    = {kNumElements, 1, 1};
857         const VkBufferImageCopy copyRegion              = {0, 0, 0, imageSubresource, imageOffset, imageExtent};
858 
859         m_vk.cmdCopyImageToBuffer(*m_cmdBuffer, *m_imageColor, VK_IMAGE_LAYOUT_GENERAL, *m_outputBuffer, 1,
860                                   &copyRegion);
861 
862         VK_CHECK(m_vk.endCommandBuffer(*m_cmdBuffer));
863 
864         submitCommandsAndWait(m_vk, device, queue, m_cmdBuffer.get());
865         flushMappedMemoryRange(m_vk, device, m_outputAlloc->getMemory(), 0, VK_WHOLE_SIZE);
866     }
867 
    // Parameters selecting the operation, data type/class and shader stage of this test variant.
    TestParam m_testParam;
    // Device entry points used for every Vulkan call made by this instance.
    const DeviceInterface &m_vk;

    // dispatch() and render() both record into m_cmdBuffer and submit it.
    Move<VkCommandPool> m_cmdPool;
    Move<VkCommandBuffer> m_cmdBuffer;
    // Descriptor and pipeline objects. generateComputePipeline() creates them for the compute path;
    // the graphics path presumably creates them elsewhere (not visible here) — TODO confirm.
    Move<VkDescriptorPool> m_descriptorPool;
    Move<VkDescriptorSet> m_descriptorSet;
    Move<VkDescriptorSetLayout> m_descriptorSetLayout;
    Move<VkPipelineLayout> m_pipelineLayout;
    Move<VkPipeline> m_pipeline;
    // Input values filled through a host mapping; bound as the inputBuffer storage buffer for
    // StorageBuffer-class tests.
    Move<VkBuffer> m_inputBuffer;
    de::MovePtr<Allocation> m_inputAlloc;
    // Results: written directly by the compute shader, or the destination of the color-image copy in render().
    Move<VkBuffer> m_outputBuffer;
    de::MovePtr<Allocation> m_outputAlloc;
    // Vertex data bound for the draw in render().
    Move<VkBuffer> m_vertexBuffer;
    de::MovePtr<Allocation> m_vertexAlloc;
    // Color attachment rendered in render() and then copied into m_outputBuffer.
    Move<VkImage> m_imageColor;
    de::MovePtr<Allocation> m_imageColorAlloc;
    Move<VkImageView> m_imageColorView;
887 };
888 
889 class ShaderExpectAssumeCase : public TestCase
890 {
891 public:
ShaderExpectAssumeCase(tcu::TestContext & testCtx,TestParam testParam)892     ShaderExpectAssumeCase(tcu::TestContext &testCtx, TestParam testParam)
893         : TestCase(testCtx, testParam.testName)
894         , m_testParam(testParam)
895     {
896     }
897     ShaderExpectAssumeCase(const ShaderExpectAssumeCase &)            = delete;
898     ShaderExpectAssumeCase &operator=(const ShaderExpectAssumeCase &) = delete;
899 
createInstance(Context & ctx) const900     TestInstance *createInstance(Context &ctx) const override
901     {
902         return new ShaderExpectAssumeTestInstance(ctx, m_testParam);
903     }
904 
initPrograms(vk::SourceCollections & programCollection) const905     void initPrograms(vk::SourceCollections &programCollection) const override
906     {
907         std::map<std::string, std::string> params;
908 
909         params["TEST_ELEMENT_COUNT"] = std::to_string(kNumElements);
910         assert(kNumElements < 127); // less than int byte
911 
912         switch (m_testParam.opType)
913         {
914         case OpType::Expect:
915             params["TEST_OPERATOR"] = "expectKHR";
916             break;
917         case OpType::Assume:
918             params["TEST_OPERATOR"] = "assumeTrueKHR";
919             break;
920         default:
921             assert(false);
922         }
923 
924         // default no need additional extension.
925         params["DATATYPE_EXTENSION_ENABLE"] = "";
926 
927         switch (m_testParam.dataType)
928         {
929         case DataType::Bool:
930             if (m_testParam.dataChannelCount == 1)
931             {
932                 params["DATATYPE"] = "bool";
933             }
934             else
935             {
936                 params["DATATYPE"] = "bvec" + std::to_string(m_testParam.dataChannelCount);
937             }
938             break;
939         case DataType::Int8:
940             assert(m_testParam.opType != OpType::Assume);
941             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable";
942             if (m_testParam.dataChannelCount == 1)
943             {
944                 params["DATATYPE"] = "int8_t";
945             }
946             else
947             {
948                 params["DATATYPE"] = "i8vec" + std::to_string(m_testParam.dataChannelCount);
949             }
950             break;
951         case DataType::Int16:
952             assert(m_testParam.opType != OpType::Assume);
953             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int16: enable";
954             if (m_testParam.dataChannelCount == 1)
955             {
956                 params["DATATYPE"] = "int16_t";
957             }
958             else
959             {
960                 params["DATATYPE"] = "i16vec" + std::to_string(m_testParam.dataChannelCount);
961             }
962             break;
963         case DataType::Int32:
964             assert(m_testParam.opType != OpType::Assume);
965             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int32: enable";
966             if (m_testParam.dataChannelCount == 1)
967             {
968                 params["DATATYPE"] = "int32_t";
969             }
970             else
971             {
972                 params["DATATYPE"] = "i32vec" + std::to_string(m_testParam.dataChannelCount);
973             }
974             break;
975         case DataType::Int64:
976             assert(m_testParam.opType != OpType::Assume);
977             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable";
978             if (m_testParam.dataChannelCount == 1)
979             {
980                 params["DATATYPE"] = "int64_t";
981             }
982             else
983             {
984                 params["DATATYPE"] = "i64vec" + std::to_string(m_testParam.dataChannelCount);
985             }
986             break;
987         default:
988             assert(false);
989         }
990 
991         switch (m_testParam.dataClass)
992         {
993         case DataClass::Constant:
994             assert(m_testParam.dataChannelCount == 1);
995 
996             params["VARNAME"] = "kThisIsTrue";
997             if (m_testParam.opType == OpType::Expect)
998             {
999                 params["EXPECTEDVALUE"] = "true";
1000                 params["WRONGVALUE"]    = "false";
1001             }
1002             break;
1003         case DataClass::SpecializationConstant:
1004             assert(m_testParam.dataChannelCount == 1);
1005 
1006             params["VARNAME"] = "scThisIsTrue";
1007             if (m_testParam.opType == OpType::Expect)
1008             {
1009                 params["EXPECTEDVALUE"] = "true";
1010                 params["WRONGVALUE"]    = "false";
1011             }
1012             break;
1013         case DataClass::StorageBuffer:
1014         {
1015             std::string indexingOffset;
1016             switch (m_testParam.shaderType)
1017             {
1018             case VK_SHADER_STAGE_COMPUTE_BIT:
1019                 indexingOffset = "gl_GlobalInvocationID.x";
1020                 break;
1021             case VK_SHADER_STAGE_VERTEX_BIT:
1022                 indexingOffset = "gl_VertexIndex";
1023                 break;
1024             case VK_SHADER_STAGE_FRAGMENT_BIT:
1025                 indexingOffset = "uint(gl_FragCoord.x)";
1026                 break;
1027             default:
1028                 assert(false);
1029             }
1030 
1031             params["VARNAME"] = "inputBuffer[" + indexingOffset + "]";
1032 
1033             if (m_testParam.opType == OpType::Expect)
1034             {
1035                 if (m_testParam.dataType == DataType::Bool)
1036                 {
1037                     params["EXPECTEDVALUE"] =
1038                         params["DATATYPE"] + "(true)"; // inputBuffer should be same as invocation id
1039                     params["WRONGVALUE"] =
1040                         params["DATATYPE"] + "(false)"; // inputBuffer should be same as invocation id
1041                 }
1042                 else
1043                 {
1044                     // inputBuffer should be same as invocation id + channel
1045                     params["EXPECTEDVALUE"] = params["DATATYPE"] + "(" + indexingOffset;
1046                     for (uint32_t channel = 1; channel < m_testParam.dataChannelCount; channel++) // from channel 1
1047                     {
1048                         params["EXPECTEDVALUE"] += ", " + indexingOffset + " + " + std::to_string(channel);
1049                     }
1050                     params["EXPECTEDVALUE"] += ")";
1051 
1052                     params["WRONGVALUE"] = params["DATATYPE"] + "(" + indexingOffset + "*2 + 3";
1053                     for (uint32_t channel = 1; channel < m_testParam.dataChannelCount; channel++) // from channel 1
1054                     {
1055                         params["WRONGVALUE"] += ", " + indexingOffset + "*2 + 3" + " + " + std::to_string(channel);
1056                     }
1057                     params["WRONGVALUE"] += ")";
1058                 }
1059             }
1060             break;
1061         }
1062         case DataClass::PushConstant:
1063             assert(m_testParam.dataChannelCount == 1);
1064             params["VARNAME"] = "pcThisIsTrue";
1065 
1066             if (m_testParam.opType == OpType::Expect)
1067             {
1068                 params["EXPECTEDVALUE"] = "true";
1069                 params["WRONGVALUE"]    = "false";
1070             }
1071 
1072             break;
1073         default:
1074             assert(false);
1075         }
1076 
1077         assert(!params["VARNAME"].empty());
1078         if (params["EXPECTEDVALUE"].empty())
1079         {
1080             params["TEST_OPERANDS"] = "(" + params["VARNAME"] + ")";
1081         }
1082         else
1083         {
1084             params["TEST_OPERANDS"] = "(" + params["VARNAME"] + ", " + params["EXPECTEDVALUE"] + ")";
1085         }
1086 
1087         switch (m_testParam.shaderType)
1088         {
1089         case VK_SHADER_STAGE_COMPUTE_BIT:
1090             addComputeTestShader(programCollection, params);
1091             break;
1092         case VK_SHADER_STAGE_VERTEX_BIT:
1093             addVertexTestShaders(programCollection, params);
1094             break;
1095         case VK_SHADER_STAGE_FRAGMENT_BIT:
1096             addFragmentTestShaders(programCollection, params);
1097             break;
1098         default:
1099             assert(0);
1100         }
1101     }
1102 
checkSupport(Context & context) const1103     void checkSupport(Context &context) const override
1104     {
1105         context.requireDeviceFunctionality("VK_KHR_shader_expect_assume");
1106 
1107         const auto &features          = context.getDeviceFeatures();
1108         const auto &featuresStorage16 = context.get16BitStorageFeatures();
1109         const auto &featuresF16I8     = context.getShaderFloat16Int8Features();
1110         const auto &featuresStorage8  = context.get8BitStorageFeatures();
1111 
1112         if (m_testParam.dataType == DataType::Int64)
1113         {
1114             if (!features.shaderInt64)
1115                 TCU_THROW(NotSupportedError, "64-bit integers not supported");
1116         }
1117         else if (m_testParam.dataType == DataType::Int16)
1118         {
1119             context.requireDeviceFunctionality("VK_KHR_16bit_storage");
1120 
1121             if (!features.shaderInt16)
1122                 TCU_THROW(NotSupportedError, "16-bit integers not supported");
1123 
1124             if (!featuresStorage16.storageBuffer16BitAccess)
1125                 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
1126         }
1127         else if (m_testParam.dataType == DataType::Int8)
1128         {
1129             context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
1130             context.requireDeviceFunctionality("VK_KHR_8bit_storage");
1131 
1132             if (!featuresF16I8.shaderInt8)
1133                 TCU_THROW(NotSupportedError, "8-bit integers not supported");
1134 
1135             if (!featuresStorage8.storageBuffer8BitAccess)
1136                 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
1137 
1138             if (!featuresStorage8.uniformAndStorageBuffer8BitAccess)
1139                 TCU_THROW(NotSupportedError, "8-bit Uniform storage buffer access not supported");
1140         }
1141     }
1142 
1143 private:
addComputeTestShader(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1144     void addComputeTestShader(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1145     {
1146         std::stringstream compShader;
1147 
1148         // Compute shader copies color to linear layout in buffer memory
1149         compShader << "#version 460 core\n"
1150                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1151                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1152                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1153                    << "void assumeTrueKHR(bool);\n"
1154                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1155                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1156                    << "precision highp float;\n"
1157                    << "precision highp int;\n"
1158                    << "layout(set = 0, binding = 0, std430) buffer Block0 { uvec2 outputBuffer[]; };\n";
1159 
1160         // declare input variable.
1161         if (m_testParam.dataClass == DataClass::Constant)
1162         {
1163             compShader << "bool kThisIsTrue = true;\n";
1164         }
1165         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1166         {
1167             compShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1168         }
1169         else if (m_testParam.dataClass == DataClass::PushConstant)
1170         {
1171             compShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1172         }
1173         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1174         {
1175             compShader << "layout(set = 0, binding = 1, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1176         }
1177 
1178         compShader << "layout(local_size_x = ${TEST_ELEMENT_COUNT}, local_size_y = 1, local_size_z = 1) in;\n"
1179                    << "void main()\n"
1180                    << "{\n";
1181         if (m_testParam.opType == OpType::Assume)
1182         {
1183             compShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1184         }
1185         else if (m_testParam.opType == OpType::Expect)
1186         {
1187             compShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1188                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1189                        << "        control = ${EXPECTEDVALUE};\n"
1190                        << "    } else {\n"
1191                        << "        // set wrong value\n"
1192                        << "        control = ${WRONGVALUE};\n"
1193                        << "    }\n";
1194         }
1195         compShader << "    outputBuffer[gl_GlobalInvocationID.x].x = gl_GlobalInvocationID.x;\n";
1196 
1197         if (params["EXPECTEDVALUE"].empty())
1198         {
1199             compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(${VARNAME});\n";
1200         }
1201         else
1202         {
1203             if (m_testParam.opType == OpType::Assume)
1204             {
1205                 compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1206             }
1207             else if (m_testParam.opType == OpType::Expect)
1208             {
1209                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1210                 if (m_testParam.wrongExpectation)
1211                     compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(control == ${WRONGVALUE});\n";
1212                 else
1213                     compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(control == ${EXPECTEDVALUE});\n";
1214             }
1215         }
1216         compShader << "}\n";
1217 
1218         tcu::StringTemplate computeShaderTpl(compShader.str());
1219         programCollection.glslSources.add("comp") << glu::ComputeSource(computeShaderTpl.specialize(params));
1220     }
1221 
addVertexTestShaders(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1222     void addVertexTestShaders(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1223     {
1224         //vertex shader
1225         std::stringstream vertShader;
1226         vertShader << "#version 460\n"
1227                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1228                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1229                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1230                    << "void assumeTrueKHR(bool);\n"
1231                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1232                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1233                    << "precision highp float;\n"
1234                    << "precision highp int;\n"
1235                    << "layout(location = 0) in vec4 in_position;\n"
1236                    << "layout(location = 0) out flat uint value;\n";
1237 
1238         // declare input variable.
1239         if (m_testParam.dataClass == DataClass::Constant)
1240         {
1241             vertShader << "bool kThisIsTrue = true;\n";
1242         }
1243         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1244         {
1245             vertShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1246         }
1247         else if (m_testParam.dataClass == DataClass::PushConstant)
1248         {
1249             vertShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1250         }
1251         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1252         {
1253             vertShader << "layout(set = 0, binding = 0, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1254         }
1255 
1256         vertShader << "void main() {\n";
1257         if (m_testParam.opType == OpType::Assume)
1258         {
1259             vertShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1260         }
1261         else if (m_testParam.opType == OpType::Expect)
1262         {
1263             vertShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1264                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1265                        << "        control = ${EXPECTEDVALUE};\n"
1266                        << "    } else {\n"
1267                        << "        // set wrong value\n"
1268                        << "        control = ${WRONGVALUE};\n"
1269                        << "    }\n";
1270         }
1271 
1272         vertShader << "    gl_Position  = in_position;\n";
1273 
1274         if (params["EXPECTEDVALUE"].empty())
1275         {
1276             vertShader << "    value = uint(${VARNAME});\n";
1277         }
1278         else
1279         {
1280             if (m_testParam.opType == OpType::Assume)
1281             {
1282                 vertShader << "    value = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1283             }
1284             else if (m_testParam.opType == OpType::Expect)
1285             {
1286                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1287                 if (m_testParam.wrongExpectation)
1288                     vertShader << "    value = uint(control == ${WRONGVALUE});\n";
1289                 else
1290                     vertShader << "    value = uint(control == ${EXPECTEDVALUE});\n";
1291             }
1292         }
1293         vertShader << "}\n";
1294 
1295         tcu::StringTemplate vertexShaderTpl(vertShader.str());
1296         programCollection.glslSources.add("vert") << glu::VertexSource(vertexShaderTpl.specialize(params));
1297 
1298         // fragment shader
1299         std::stringstream fragShader;
1300         fragShader << "#version 460\n"
1301                    << "precision highp float;\n"
1302                    << "precision highp int;\n"
1303                    << "layout(location = 0) in flat uint value;\n"
1304                    << "layout(location = 0) out uvec2 out_color;\n"
1305                    << "void main()\n"
1306                    << "{\n"
1307                    << "    out_color.r = uint(gl_FragCoord.x);\n"
1308                    << "    out_color.g = value;\n"
1309                    << "}\n";
1310 
1311         tcu::StringTemplate fragmentShaderTpl(fragShader.str());
1312         programCollection.glslSources.add("frag") << glu::FragmentSource(fragmentShaderTpl.specialize(params));
1313     }
1314 
addFragmentTestShaders(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1315     void addFragmentTestShaders(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1316     {
1317         //vertex shader
1318         std::stringstream vertShader;
1319         vertShader << "#version 460\n"
1320                    << "precision highp float;\n"
1321                    << "precision highp int;\n"
1322                    << "layout(location = 0) in vec4 in_position;\n"
1323                    << "void main() {\n"
1324                    << "    gl_Position  = in_position;\n"
1325                    << "}\n";
1326 
1327         tcu::StringTemplate vertexShaderTpl(vertShader.str());
1328         programCollection.glslSources.add("vert") << glu::VertexSource(vertexShaderTpl.specialize(params));
1329 
1330         // fragment shader
1331         std::stringstream fragShader;
1332         fragShader << "#version 460\n"
1333                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1334                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1335                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1336                    << "void assumeTrueKHR(bool);\n"
1337                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1338                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1339                    << "precision highp float;\n"
1340                    << "precision highp int;\n"
1341                    << "layout(location = 0) out uvec2 out_color;\n";
1342         if (m_testParam.dataClass == DataClass::Constant)
1343         {
1344             fragShader << "bool kThisIsTrue = true;\n";
1345         }
1346         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1347         {
1348             fragShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1349         }
1350         else if (m_testParam.dataClass == DataClass::PushConstant)
1351         {
1352             fragShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1353         }
1354         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1355         {
1356             fragShader << "layout(set = 0, binding = 0, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1357         }
1358 
1359         fragShader << "void main()\n"
1360                    << "{\n";
1361 
1362         if (m_testParam.opType == OpType::Assume)
1363         {
1364             fragShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1365         }
1366         else if (m_testParam.opType == OpType::Expect)
1367         {
1368             fragShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1369                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1370                        << "        control = ${EXPECTEDVALUE};\n"
1371                        << "    } else {\n"
1372                        << "        // set wrong value\n"
1373                        << "        control = ${WRONGVALUE};\n"
1374                        << "    }\n";
1375         }
1376         fragShader << "    out_color.r = int(gl_FragCoord.x);\n";
1377 
1378         if (params["EXPECTEDVALUE"].empty())
1379         {
1380             fragShader << "    out_color.g = uint(${VARNAME});\n";
1381         }
1382         else
1383         {
1384             if (m_testParam.opType == OpType::Assume)
1385             {
1386                 fragShader << "    out_color.g = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1387             }
1388             else if (m_testParam.opType == OpType::Expect)
1389             {
1390                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1391                 if (m_testParam.wrongExpectation)
1392                     fragShader << "    out_color.g = uint(control == ${WRONGVALUE});\n";
1393                 else
1394                     fragShader << "    out_color.g = uint(control == ${EXPECTEDVALUE});\n";
1395             }
1396         }
1397         fragShader << "}\n";
1398 
1399         tcu::StringTemplate fragmentShaderTpl(fragShader.str());
1400         programCollection.glslSources.add("frag") << glu::FragmentSource(fragmentShaderTpl.specialize(params));
1401     }
1402 
1403 private:
1404     TestParam m_testParam;
1405 };
1406 
addShaderExpectAssumeTests(tcu::TestCaseGroup * testGroup)1407 void addShaderExpectAssumeTests(tcu::TestCaseGroup *testGroup)
1408 {
1409     VkShaderStageFlagBits stages[] = {
1410         VK_SHADER_STAGE_VERTEX_BIT,
1411         VK_SHADER_STAGE_FRAGMENT_BIT,
1412         VK_SHADER_STAGE_COMPUTE_BIT,
1413     };
1414 
1415     TestParam testParams[] = {
1416         {OpType::Expect, DataClass::Constant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "constant"},
1417         {OpType::Expect, DataClass::SpecializationConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false,
1418          "specializationconstant"},
1419         {OpType::Expect, DataClass::PushConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "pushconstant"},
1420         {OpType::Expect, DataClass::StorageBuffer, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer_bool"},
1421         {OpType::Expect, DataClass::StorageBuffer, DataType::Int8, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer_int8"},
1422         {OpType::Expect, DataClass::StorageBuffer, DataType::Int16, 0, VK_SHADER_STAGE_ALL, false,
1423          "storagebuffer_int16"},
1424         {OpType::Expect, DataClass::StorageBuffer, DataType::Int32, 0, VK_SHADER_STAGE_ALL, false,
1425          "storagebuffer_int32"},
1426         {OpType::Expect, DataClass::StorageBuffer, DataType::Int64, 0, VK_SHADER_STAGE_ALL, false,
1427          "storagebuffer_int64"},
1428         {OpType::Assume, DataClass::Constant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "constant"},
1429         {OpType::Assume, DataClass::SpecializationConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false,
1430          "specializationconstant"},
1431         {OpType::Assume, DataClass::PushConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "pushconstant"},
1432         {OpType::Assume, DataClass::StorageBuffer, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer"},
1433     };
1434 
1435     tcu::TestContext &testCtx = testGroup->getTestContext();
1436 
1437     for (VkShaderStageFlagBits stage : stages)
1438     {
1439         const char *stageName = (stage == VK_SHADER_STAGE_VERTEX_BIT)   ? ("vertex") :
1440                                 (stage == VK_SHADER_STAGE_FRAGMENT_BIT) ? ("fragment") :
1441                                 (stage == VK_SHADER_STAGE_COMPUTE_BIT)  ? ("compute") :
1442                                                                           (nullptr);
1443 
1444         const std::string setName = std::string() + stageName;
1445         de::MovePtr<tcu::TestCaseGroup> stageGroupTest(
1446             new tcu::TestCaseGroup(testCtx, setName.c_str(), "Shader Expect Assume Tests"));
1447 
1448         de::MovePtr<tcu::TestCaseGroup> expectGroupTest(
1449             new tcu::TestCaseGroup(testCtx, "expect", "Shader Expect Tests"));
1450 
1451         de::MovePtr<tcu::TestCaseGroup> assumeGroupTest(
1452             new tcu::TestCaseGroup(testCtx, "assume", "Shader Assume Tests"));
1453 
1454         for (uint32_t expectationState = 0; expectationState < 2; expectationState++)
1455         {
1456             bool wrongExpected = (expectationState == 0) ? false : true;
1457             for (uint32_t channelCount = 1; channelCount <= 4; channelCount++)
1458             {
1459                 for (TestParam testParam : testParams)
1460                 {
1461                     testParam.dataChannelCount = channelCount;
1462                     testParam.wrongExpectation = wrongExpected;
1463                     if (channelCount > 1 || wrongExpected)
1464                     {
1465                         if (testParam.opType != OpType::Expect || testParam.dataClass != DataClass::StorageBuffer)
1466                         {
1467                             continue;
1468                         }
1469 
1470                         if (channelCount > 1)
1471                         {
1472                             testParam.testName = testParam.testName + "_vec" + std::to_string(channelCount);
1473                         }
1474 
1475                         if (wrongExpected)
1476                         {
1477                             testParam.testName = testParam.testName + "_wrong_expected";
1478                         }
1479                     }
1480 
1481                     testParam.shaderType = stage;
1482 
1483                     switch (testParam.opType)
1484                     {
1485                     case OpType::Expect:
1486                         expectGroupTest->addChild(new ShaderExpectAssumeCase(testCtx, testParam));
1487                         break;
1488                     case OpType::Assume:
1489                         assumeGroupTest->addChild(new ShaderExpectAssumeCase(testCtx, testParam));
1490                         break;
1491                     default:
1492                         assert(false);
1493                     }
1494                 }
1495             }
1496         }
1497 
1498         stageGroupTest->addChild(expectGroupTest.release());
1499         stageGroupTest->addChild(assumeGroupTest.release());
1500 
1501         testGroup->addChild(stageGroupTest.release());
1502     }
1503 }
1504 
1505 } // namespace
1506 
createShaderExpectAssumeTests(tcu::TestContext & testCtx)1507 tcu::TestCaseGroup *createShaderExpectAssumeTests(tcu::TestContext &testCtx)
1508 {
1509     return createTestGroup(testCtx, "shader_expect_assume", addShaderExpectAssumeTests);
1510 }
1511 
1512 } // namespace shaderexecutor
1513 } // namespace vkt
1514