xref: /aosp_15_r20/external/swiftshader/tests/VulkanUnitTests/ComputeTests.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright 2021 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "Device.hpp"
16 #include "Driver.hpp"
17 
18 #include "gmock/gmock.h"
19 #include "gtest/gtest.h"
20 
21 #include "spirv-tools/libspirv.hpp"
22 
23 #include <cstring>
24 #include <sstream>
25 
26 namespace {
// Rounds |val| up to the next multiple of |alignment|. A value that is already
// a multiple is returned unchanged. |alignment| is used as a divisor and so
// must be non-zero.
size_t alignUp(size_t val, size_t alignment)
{
	const size_t wholeBlocks = (val + alignment - 1) / alignment;
	return wholeBlocks * alignment;
}
31 }  // anonymous namespace
32 
// Parameter set for one instantiation of the value-parameterized compute
// tests in this file.
struct ComputeParams
{
	// Number of uint32_t elements in each of the input and output buffers.
	size_t numElements;
	// Workgroup dimensions, emitted into the shader's LocalSize execution mode.
	int localSizeX;
	int localSizeY;
	int localSizeZ;

	// Streams a human readable description of |params| to |os|.
	friend std::ostream &operator<<(std::ostream &os, const ComputeParams &params)
	{
		os << "ComputeParams{";
		os << "numElements: " << params.numElements << ", ";
		os << "localSizeX: " << params.localSizeX << ", ";
		os << "localSizeY: " << params.localSizeY << ", ";
		os << "localSizeZ: " << params.localSizeZ << "}";
		return os;
	}
};
49 
50 class ComputeTest : public testing::TestWithParam<ComputeParams>
51 {
52 protected:
53 	static Driver driver;
54 
SetUpTestSuite()55 	static void SetUpTestSuite()
56 	{
57 		ASSERT_TRUE(driver.loadSwiftShader());
58 	}
59 
TearDownTestSuite()60 	static void TearDownTestSuite()
61 	{
62 		driver.unload();
63 	}
64 };
65 
66 Driver ComputeTest::driver;
67 
compileSpirv(const char * assembly)68 std::vector<uint32_t> compileSpirv(const char *assembly)
69 {
70 	spvtools::SpirvTools core(SPV_ENV_VULKAN_1_0);
71 
72 	core.SetMessageConsumer([](spv_message_level_t, const char *, const spv_position_t &p, const char *m) {
73 		FAIL() << p.line << ":" << p.column << ": " << m;
74 	});
75 
76 	std::vector<uint32_t> spirv;
77 	EXPECT_TRUE(core.Assemble(assembly, &spirv));
78 	EXPECT_TRUE(core.Validate(spirv));
79 
80 	// Warn if the disassembly does not match the source assembly.
81 	// We do this as debugging tests in the debugger is often made much harder
82 	// if the SSA names (%X) in the debugger do not match the source.
83 	std::string disassembled;
84 	core.Disassemble(spirv, &disassembled, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
85 	if(disassembled != assembly)
86 	{
87 		printf("-- WARNING: Disassembly does not match assembly: ---\n\n");
88 
89 		auto splitLines = [](const std::string &str) -> std::vector<std::string> {
90 			std::stringstream ss(str);
91 			std::vector<std::string> out;
92 			std::string line;
93 			while(std::getline(ss, line, '\n')) { out.push_back(line); }
94 			return out;
95 		};
96 
97 		auto srcLines = splitLines(std::string(assembly));
98 		auto disLines = splitLines(disassembled);
99 
100 		for(size_t line = 0; line < srcLines.size() && line < disLines.size(); line++)
101 		{
102 			auto srcLine = (line < srcLines.size()) ? srcLines[line] : "<missing>";
103 			auto disLine = (line < disLines.size()) ? disLines[line] : "<missing>";
104 			if(srcLine != disLine)
105 			{
106 				printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());
107 			}
108 		}
109 		printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());
110 	}
111 
112 	return spirv;
113 }
114 
115 #define VK_ASSERT(x) ASSERT_EQ(x, VK_SUCCESS)
116 
117 // Base class for compute tests that read from an input buffer and write to an
118 // output buffer of same length.
119 class SwiftShaderVulkanBufferToBufferComputeTest : public ComputeTest
120 {
121 public:
122 	void test(const std::string &shader,
123 	          std::function<uint32_t(uint32_t idx)> input,
124 	          std::function<uint32_t(uint32_t idx)> expected);
125 };
126 
test(const std::string & shader,std::function<uint32_t (uint32_t idx)> input,std::function<uint32_t (uint32_t idx)> expected)127 void SwiftShaderVulkanBufferToBufferComputeTest::test(
128     const std::string &shader,
129     std::function<uint32_t(uint32_t idx)> input,
130     std::function<uint32_t(uint32_t idx)> expected)
131 {
132 	auto code = compileSpirv(shader.c_str());
133 
134 	const VkInstanceCreateInfo createInfo = {
135 		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType
136 		nullptr,                                 // pNext
137 		0,                                       // flags
138 		nullptr,                                 // pApplicationInfo
139 		0,                                       // enabledLayerCount
140 		nullptr,                                 // ppEnabledLayerNames
141 		0,                                       // enabledExtensionCount
142 		nullptr,                                 // ppEnabledExtensionNames
143 	};
144 
145 	VkInstance instance = VK_NULL_HANDLE;
146 	VK_ASSERT(driver.vkCreateInstance(&createInfo, nullptr, &instance));
147 
148 	ASSERT_TRUE(driver.resolve(instance));
149 
150 	std::unique_ptr<Device> device;
151 	VK_ASSERT(Device::CreateComputeDevice(&driver, instance, device));
152 	ASSERT_TRUE(device->IsValid());
153 
154 	// struct Buffers
155 	// {
156 	//     uint32_t pad0[63];
157 	//     uint32_t magic0;
158 	//     uint32_t in[NUM_ELEMENTS]; // Aligned to 0x100
159 	//     uint32_t magic1;
160 	//     uint32_t pad1[N];
161 	//     uint32_t magic2;
162 	//     uint32_t out[NUM_ELEMENTS]; // Aligned to 0x100
163 	//     uint32_t magic3;
164 	// };
165 	static constexpr uint32_t magic0 = 0x01234567;
166 	static constexpr uint32_t magic1 = 0x89abcdef;
167 	static constexpr uint32_t magic2 = 0xfedcba99;
168 	static constexpr uint32_t magic3 = 0x87654321;
169 	size_t numElements = GetParam().numElements;
170 	size_t alignElements = 0x100 / sizeof(uint32_t);
171 	size_t magic0Offset = alignElements - 1;
172 	size_t inOffset = 1 + magic0Offset;
173 	size_t magic1Offset = numElements + inOffset;
174 	size_t magic2Offset = alignUp(magic1Offset + 1, alignElements) - 1;
175 	size_t outOffset = 1 + magic2Offset;
176 	size_t magic3Offset = numElements + outOffset;
177 	size_t buffersTotalElements = alignUp(1 + magic3Offset, alignElements);
178 	size_t buffersSize = sizeof(uint32_t) * buffersTotalElements;
179 
180 	VkDeviceMemory memory;
181 	VK_ASSERT(device->AllocateMemory(buffersSize,
182 	                                 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
183 	                                 &memory));
184 
185 	uint32_t *buffers;
186 	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
187 
188 	memset(buffers, 0, buffersSize);
189 
190 	buffers[magic0Offset] = magic0;
191 	buffers[magic1Offset] = magic1;
192 	buffers[magic2Offset] = magic2;
193 	buffers[magic3Offset] = magic3;
194 
195 	for(size_t i = 0; i < numElements; i++)
196 	{
197 		buffers[inOffset + i] = input((uint32_t)i);
198 	}
199 
200 	device->UnmapMemory(memory);
201 	buffers = nullptr;
202 
203 	VkBuffer bufferIn;
204 	VK_ASSERT(device->CreateStorageBuffer(memory,
205 	                                      sizeof(uint32_t) * numElements,
206 	                                      sizeof(uint32_t) * inOffset,
207 	                                      &bufferIn));
208 
209 	VkBuffer bufferOut;
210 	VK_ASSERT(device->CreateStorageBuffer(memory,
211 	                                      sizeof(uint32_t) * numElements,
212 	                                      sizeof(uint32_t) * outOffset,
213 	                                      &bufferOut));
214 
215 	VkShaderModule shaderModule;
216 	VK_ASSERT(device->CreateShaderModule(code, &shaderModule));
217 
218 	std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings = {
219 		{
220 		    0,                                  // binding
221 		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
222 		    1,                                  // descriptorCount
223 		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
224 		    0,                                  // pImmutableSamplers
225 		},
226 		{
227 		    1,                                  // binding
228 		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
229 		    1,                                  // descriptorCount
230 		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
231 		    0,                                  // pImmutableSamplers
232 		}
233 	};
234 
235 	VkDescriptorSetLayout descriptorSetLayout;
236 	VK_ASSERT(device->CreateDescriptorSetLayout(descriptorSetLayoutBindings, &descriptorSetLayout));
237 
238 	VkPipelineLayout pipelineLayout;
239 	VK_ASSERT(device->CreatePipelineLayout(descriptorSetLayout, &pipelineLayout));
240 
241 	VkPipeline pipeline;
242 	VK_ASSERT(device->CreateComputePipeline(shaderModule, pipelineLayout, &pipeline));
243 
244 	VkDescriptorPool descriptorPool;
245 	VK_ASSERT(device->CreateStorageBufferDescriptorPool(2, &descriptorPool));
246 
247 	VkDescriptorSet descriptorSet;
248 	VK_ASSERT(device->AllocateDescriptorSet(descriptorPool, descriptorSetLayout, &descriptorSet));
249 
250 	std::vector<VkDescriptorBufferInfo> descriptorBufferInfos = {
251 		{
252 		    bufferIn,       // buffer
253 		    0,              // offset
254 		    VK_WHOLE_SIZE,  // range
255 		},
256 		{
257 		    bufferOut,      // buffer
258 		    0,              // offset
259 		    VK_WHOLE_SIZE,  // range
260 		}
261 	};
262 	device->UpdateStorageBufferDescriptorSets(descriptorSet, descriptorBufferInfos);
263 
264 	VkCommandPool commandPool;
265 	VK_ASSERT(device->CreateCommandPool(&commandPool));
266 
267 	VkCommandBuffer commandBuffer;
268 	VK_ASSERT(device->AllocateCommandBuffer(commandPool, &commandBuffer));
269 
270 	VK_ASSERT(device->BeginCommandBuffer(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, commandBuffer));
271 
272 	driver.vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
273 
274 	driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,
275 	                               0, nullptr);
276 
277 	driver.vkCmdDispatch(commandBuffer, (uint32_t)(numElements / GetParam().localSizeX), 1, 1);
278 
279 	VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));
280 
281 	VK_ASSERT(device->QueueSubmitAndWait(commandBuffer));
282 
283 	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
284 
285 	for(size_t i = 0; i < numElements; ++i)
286 	{
287 		auto got = buffers[i + outOffset];
288 		EXPECT_EQ(expected((uint32_t)i), got) << "Unexpected output at " << i;
289 	}
290 
291 	// Check for writes outside of bounds.
292 	EXPECT_EQ(buffers[magic0Offset], magic0);
293 	EXPECT_EQ(buffers[magic1Offset], magic1);
294 	EXPECT_EQ(buffers[magic2Offset], magic2);
295 	EXPECT_EQ(buffers[magic3Offset], magic3);
296 
297 	device->UnmapMemory(memory);
298 	buffers = nullptr;
299 
300 	device->FreeCommandBuffer(commandPool, commandBuffer);
301 	device->FreeMemory(memory);
302 	device->DestroyPipeline(pipeline);
303 	device->DestroyCommandPool(commandPool);
304 	device->DestroyPipelineLayout(pipelineLayout);
305 	device->DestroyDescriptorSetLayout(descriptorSetLayout);
306 	device->DestroyDescriptorPool(descriptorPool);
307 	device->DestroyBuffer(bufferIn);
308 	device->DestroyBuffer(bufferOut);
309 	device->DestroyShaderModule(shaderModule);
310 	device.reset(nullptr);
311 	driver.vkDestroyInstance(instance, nullptr);
312 }
313 
314 INSTANTIATE_TEST_SUITE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(ComputeParams{ 512, 1, 1, 1 }, ComputeParams{ 512, 2, 1, 1 }, ComputeParams{ 512, 4, 1, 1 }, ComputeParams{ 512, 8, 1, 1 }, ComputeParams{ 512, 16, 1, 1 }, ComputeParams{ 512, 32, 1, 1 },
315 
316                                                                                                     // Non-multiple of SIMD-lane.
317                                                                                                     ComputeParams{ 3, 1, 1, 1 }, ComputeParams{ 2, 1, 1, 1 }));
318 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,Memcpy)319 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)
320 {
321 	std::stringstream src;
322 	// #version 450
323 	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
324 	// layout(binding = 0, std430) buffer InBuffer
325 	// {
326 	//     int Data[];
327 	// } In;
328 	// layout(binding = 1, std430) buffer OutBuffer
329 	// {
330 	//     int Data[];
331 	// } Out;
332 	// void main()
333 	// {
334 	//     Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];
335 	// }
336 	// clang-format off
337     src <<
338         "OpCapability Shader\n"
339         "OpMemoryModel Logical GLSL450\n"
340         "OpEntryPoint GLCompute %1 \"main\" %2\n"
341         "OpExecutionMode %1 LocalSize " <<
342         GetParam().localSizeX << " " <<
343         GetParam().localSizeY << " " <<
344         GetParam().localSizeZ << "\n" <<
345         "OpDecorate %3 ArrayStride 4\n"
346         "OpMemberDecorate %4 0 Offset 0\n"
347         "OpDecorate %4 BufferBlock\n"
348         "OpDecorate %5 DescriptorSet 0\n"
349         "OpDecorate %5 Binding 1\n"
350         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
351         "OpDecorate %6 DescriptorSet 0\n"
352         "OpDecorate %6 Binding 0\n"
353         "%7 = OpTypeVoid\n"
354         "%8 = OpTypeFunction %7\n"             // void()
355         "%9 = OpTypeInt 32 1\n"                // int32
356         "%10 = OpTypeInt 32 0\n"                // uint32
357         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
358         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
359         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
360         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
361         "%12 = OpConstant %9 0\n"               // int32(0)
362         "%13 = OpConstant %10 0\n"              // uint32(0)
363         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
364         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
365         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
366         "%16 = OpTypePointer Input %10\n"       // uint32*
367         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
368         "%17 = OpTypePointer Uniform %9\n"      // int32*
369         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
370         "%18 = OpLabel\n"
371         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
372         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
373         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
374         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
375         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
376         "OpStore %23 %22\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x]
377         "OpReturn\n"
378         "OpFunctionEnd\n";
379 	// clang-format on
380 
381 	test(
382 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
383 }
384 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,GlobalInvocationId)385 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, GlobalInvocationId)
386 {
387 	std::stringstream src;
388 	// clang-format off
389     src <<
390         "OpCapability Shader\n"
391         "OpMemoryModel Logical GLSL450\n"
392         "OpEntryPoint GLCompute %1 \"main\" %2\n"
393         "OpExecutionMode %1 LocalSize " <<
394         GetParam().localSizeX << " " <<
395         GetParam().localSizeY << " " <<
396         GetParam().localSizeZ << "\n" <<
397         "OpDecorate %3 ArrayStride 4\n"
398         "OpMemberDecorate %4 0 Offset 0\n"
399         "OpDecorate %4 BufferBlock\n"
400         "OpDecorate %5 DescriptorSet 0\n"
401         "OpDecorate %5 Binding 1\n"
402         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
403         "OpDecorate %6 DescriptorSet 0\n"
404         "OpDecorate %6 Binding 0\n"
405         "%7 = OpTypeVoid\n"
406         "%8 = OpTypeFunction %7\n"             // void()
407         "%9 = OpTypeInt 32 1\n"                // int32
408         "%10 = OpTypeInt 32 0\n"                // uint32
409         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
410         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
411         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
412         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
413         "%12 = OpConstant %9 0\n"               // int32(0)
414         "%13 = OpConstant %9 1\n"               // int32(1)
415         "%14 = OpConstant %10 0\n"              // uint32(0)
416         "%15 = OpConstant %10 1\n"              // uint32(1)
417         "%16 = OpConstant %10 2\n"              // uint32(2)
418         "%17 = OpTypeVector %10 3\n"            // vec3<int32>
419         "%18 = OpTypePointer Input %17\n"       // vec3<int32>*
420         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
421         "%19 = OpTypePointer Input %10\n"       // uint32*
422         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
423         "%20 = OpTypePointer Uniform %9\n"      // int32*
424         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
425         "%21 = OpLabel\n"
426         "%22 = OpAccessChain %19 %2 %14\n"      // &gl_GlobalInvocationId.x
427         "%23 = OpAccessChain %19 %2 %15\n"      // &gl_GlobalInvocationId.y
428         "%24 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.z
429         "%25 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
430         "%26 = OpLoad %10 %23\n"                // gl_GlobalInvocationId.y
431         "%27 = OpLoad %10 %24\n"                // gl_GlobalInvocationId.z
432         "%28 = OpAccessChain %20 %6 %12 %25\n"  // &in.arr[gl_GlobalInvocationId.x]
433         "%29 = OpLoad %9 %28\n"                 // out.arr[gl_GlobalInvocationId.x]
434         "%30 = OpIAdd %9 %29 %26\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y
435         "%31 = OpIAdd %9 %30 %27\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
436         "%32 = OpAccessChain %20 %5 %12 %25\n"  // &out.arr[gl_GlobalInvocationId.x]
437         "OpStore %32 %31\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
438         "OpReturn\n"
439         "OpFunctionEnd\n";
440 	// clang-format on
441 
442 	// gl_GlobalInvocationId.y and gl_GlobalInvocationId.z should both be zero.
443 	test(
444 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
445 }
446 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchSimple)447 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple)
448 {
449 	std::stringstream src;
450 	// clang-format off
451     src <<
452         "OpCapability Shader\n"
453         "OpMemoryModel Logical GLSL450\n"
454         "OpEntryPoint GLCompute %1 \"main\" %2\n"
455         "OpExecutionMode %1 LocalSize " <<
456         GetParam().localSizeX << " " <<
457         GetParam().localSizeY << " " <<
458         GetParam().localSizeZ << "\n" <<
459         "OpDecorate %3 ArrayStride 4\n"
460         "OpMemberDecorate %4 0 Offset 0\n"
461         "OpDecorate %4 BufferBlock\n"
462         "OpDecorate %5 DescriptorSet 0\n"
463         "OpDecorate %5 Binding 1\n"
464         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
465         "OpDecorate %6 DescriptorSet 0\n"
466         "OpDecorate %6 Binding 0\n"
467         "%7 = OpTypeVoid\n"
468         "%8 = OpTypeFunction %7\n"             // void()
469         "%9 = OpTypeInt 32 1\n"                // int32
470         "%10 = OpTypeInt 32 0\n"                // uint32
471         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
472         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
473         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
474         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
475         "%12 = OpConstant %9 0\n"               // int32(0)
476         "%13 = OpConstant %10 0\n"              // uint32(0)
477         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
478         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
479         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
480         "%16 = OpTypePointer Input %10\n"       // uint32*
481         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
482         "%17 = OpTypePointer Uniform %9\n"      // int32*
483         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
484         "%18 = OpLabel\n"
485         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
486         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
487         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
488         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
489         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
490                                                 // Start of branch logic
491                                                 // %22 = in value
492         "OpBranch %24\n"
493         "%24 = OpLabel\n"
494         "OpBranch %25\n"
495         "%25 = OpLabel\n"
496         "OpBranch %26\n"
497         "%26 = OpLabel\n"
498         // %22 = out value
499         // End of branch logic
500         "OpStore %23 %22\n"
501         "OpReturn\n"
502         "OpFunctionEnd\n";
503 	// clang-format on
504 
505 	test(
506 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
507 }
508 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchDeclareSSA)509 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA)
510 {
511 	std::stringstream src;
512 	// clang-format off
513     src <<
514         "OpCapability Shader\n"
515         "OpMemoryModel Logical GLSL450\n"
516         "OpEntryPoint GLCompute %1 \"main\" %2\n"
517         "OpExecutionMode %1 LocalSize " <<
518         GetParam().localSizeX << " " <<
519         GetParam().localSizeY << " " <<
520         GetParam().localSizeZ << "\n" <<
521         "OpDecorate %3 ArrayStride 4\n"
522         "OpMemberDecorate %4 0 Offset 0\n"
523         "OpDecorate %4 BufferBlock\n"
524         "OpDecorate %5 DescriptorSet 0\n"
525         "OpDecorate %5 Binding 1\n"
526         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
527         "OpDecorate %6 DescriptorSet 0\n"
528         "OpDecorate %6 Binding 0\n"
529         "%7 = OpTypeVoid\n"
530         "%8 = OpTypeFunction %7\n"             // void()
531         "%9 = OpTypeInt 32 1\n"                // int32
532         "%10 = OpTypeInt 32 0\n"                // uint32
533         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
534         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
535         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
536         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
537         "%12 = OpConstant %9 0\n"               // int32(0)
538         "%13 = OpConstant %10 0\n"              // uint32(0)
539         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
540         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
541         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
542         "%16 = OpTypePointer Input %10\n"       // uint32*
543         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
544         "%17 = OpTypePointer Uniform %9\n"      // int32*
545         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
546         "%18 = OpLabel\n"
547         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
548         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
549         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
550         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
551         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
552                                                 // Start of branch logic
553                                                 // %22 = in value
554         "OpBranch %24\n"
555         "%24 = OpLabel\n"
556         "%25 = OpIAdd %9 %22 %22\n"             // %25 = in*2
557         "OpBranch %26\n"
558         "%26 = OpLabel\n"
559         "OpBranch %27\n"
560         "%27 = OpLabel\n"
561         // %25 = out value
562         // End of branch logic
563         "OpStore %23 %25\n"               // use SSA value from previous block
564         "OpReturn\n"
565         "OpFunctionEnd\n";
566 	// clang-format on
567 
568 	test(
569 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; });
570 }
571 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalSimple)572 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalSimple)
573 {
574 	std::stringstream src;
575 	// clang-format off
576     src <<
577         "OpCapability Shader\n"
578         "OpMemoryModel Logical GLSL450\n"
579         "OpEntryPoint GLCompute %1 \"main\" %2\n"
580         "OpExecutionMode %1 LocalSize " <<
581         GetParam().localSizeX << " " <<
582         GetParam().localSizeY << " " <<
583         GetParam().localSizeZ << "\n" <<
584         "OpDecorate %3 ArrayStride 4\n"
585         "OpMemberDecorate %4 0 Offset 0\n"
586         "OpDecorate %4 BufferBlock\n"
587         "OpDecorate %5 DescriptorSet 0\n"
588         "OpDecorate %5 Binding 1\n"
589         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
590         "OpDecorate %6 DescriptorSet 0\n"
591         "OpDecorate %6 Binding 0\n"
592         "%7 = OpTypeVoid\n"
593         "%8 = OpTypeFunction %7\n"             // void()
594         "%9 = OpTypeInt 32 1\n"                // int32
595         "%10 = OpTypeInt 32 0\n"                // uint32
596         "%11 = OpTypeBool\n"
597         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
598         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
599         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
600         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
601         "%13 = OpConstant %9 0\n"               // int32(0)
602         "%14 = OpConstant %9 2\n"               // int32(2)
603         "%15 = OpConstant %10 0\n"              // uint32(0)
604         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
605         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
606         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
607         "%18 = OpTypePointer Input %10\n"       // uint32*
608         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
609         "%19 = OpTypePointer Uniform %9\n"      // int32*
610         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
611         "%20 = OpLabel\n"
612         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
613         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
614         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
615         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
616         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
617                                                 // Start of branch logic
618                                                 // %24 = in value
619         "%26 = OpSMod %9 %24 %14\n"             // in % 2
620         "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
621         "OpSelectionMerge %28 None\n"
622         "OpBranchConditional %27 %28 %28\n" // Both go to %28
623         "%28 = OpLabel\n"
624         // %26 = out value
625         // End of branch logic
626         "OpStore %25 %26\n"               // use SSA value from previous block
627         "OpReturn\n"
628         "OpFunctionEnd\n";
629 	// clang-format on
630 
631 	test(
632 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
633 }
634 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalTwoEmptyBlocks)635 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalTwoEmptyBlocks)
636 {
637 	std::stringstream src;
638 	// clang-format off
639     src <<
640         "OpCapability Shader\n"
641         "OpMemoryModel Logical GLSL450\n"
642         "OpEntryPoint GLCompute %1 \"main\" %2\n"
643         "OpExecutionMode %1 LocalSize " <<
644         GetParam().localSizeX << " " <<
645         GetParam().localSizeY << " " <<
646         GetParam().localSizeZ << "\n" <<
647         "OpDecorate %3 ArrayStride 4\n"
648         "OpMemberDecorate %4 0 Offset 0\n"
649         "OpDecorate %4 BufferBlock\n"
650         "OpDecorate %5 DescriptorSet 0\n"
651         "OpDecorate %5 Binding 1\n"
652         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
653         "OpDecorate %6 DescriptorSet 0\n"
654         "OpDecorate %6 Binding 0\n"
655         "%7 = OpTypeVoid\n"
656         "%8 = OpTypeFunction %7\n"             // void()
657         "%9 = OpTypeInt 32 1\n"                // int32
658         "%10 = OpTypeInt 32 0\n"                // uint32
659         "%11 = OpTypeBool\n"
660         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
661         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
662         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
663         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
664         "%13 = OpConstant %9 0\n"               // int32(0)
665         "%14 = OpConstant %9 2\n"               // int32(2)
666         "%15 = OpConstant %10 0\n"              // uint32(0)
667         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
668         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
669         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
670         "%18 = OpTypePointer Input %10\n"       // uint32*
671         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
672         "%19 = OpTypePointer Uniform %9\n"      // int32*
673         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
674         "%20 = OpLabel\n"
675         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
676         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
677         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
678         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
679         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
680                                                 // Start of branch logic
681                                                 // %24 = in value
682         "%26 = OpSMod %9 %24 %14\n"             // in % 2
683         "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
684         "OpSelectionMerge %28 None\n"
685         "OpBranchConditional %27 %29 %30\n"
686         "%29 = OpLabel\n"                       // (in % 2) == 0
687         "OpBranch %28\n"
688         "%30 = OpLabel\n"                       // (in % 2) != 0
689         "OpBranch %28\n"
690         "%28 = OpLabel\n"
691         // %26 = out value
692         // End of branch logic
693         "OpStore %25 %26\n"               // use SSA value from previous block
694         "OpReturn\n"
695         "OpFunctionEnd\n";
696 	// clang-format on
697 
698 	test(
699 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
700 }
701 
702 // TODO: Test for parallel assignment
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalStore)703 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalStore)
704 {
705 	std::stringstream src;
706 	// clang-format off
707     src <<
708         "OpCapability Shader\n"
709         "OpMemoryModel Logical GLSL450\n"
710         "OpEntryPoint GLCompute %1 \"main\" %2\n"
711         "OpExecutionMode %1 LocalSize " <<
712         GetParam().localSizeX << " " <<
713         GetParam().localSizeY << " " <<
714         GetParam().localSizeZ << "\n" <<
715         "OpDecorate %3 ArrayStride 4\n"
716         "OpMemberDecorate %4 0 Offset 0\n"
717         "OpDecorate %4 BufferBlock\n"
718         "OpDecorate %5 DescriptorSet 0\n"
719         "OpDecorate %5 Binding 1\n"
720         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
721         "OpDecorate %6 DescriptorSet 0\n"
722         "OpDecorate %6 Binding 0\n"
723         "%7 = OpTypeVoid\n"
724         "%8 = OpTypeFunction %7\n"             // void()
725         "%9 = OpTypeInt 32 1\n"                // int32
726         "%10 = OpTypeInt 32 0\n"                // uint32
727         "%11 = OpTypeBool\n"
728         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
729         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
730         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
731         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
732         "%13 = OpConstant %9 0\n"               // int32(0)
733         "%14 = OpConstant %9 1\n"               // int32(1)
734         "%15 = OpConstant %9 2\n"               // int32(2)
735         "%16 = OpConstant %10 0\n"              // uint32(0)
736         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
737         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
738         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
739         "%19 = OpTypePointer Input %10\n"       // uint32*
740         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
741         "%20 = OpTypePointer Uniform %9\n"      // int32*
742         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
743         "%21 = OpLabel\n"
744         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
745         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
746         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
747         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
748         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
749                                                 // Start of branch logic
750                                                 // %25 = in value
751         "%27 = OpSMod %9 %25 %15\n"             // in % 2
752         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
753         "OpSelectionMerge %29 None\n"
754         "OpBranchConditional %28 %30 %31\n"
755         "%30 = OpLabel\n"                       // (in % 2) == 0
756         "OpStore %26 %14\n"               // write 1
757         "OpBranch %29\n"
758         "%31 = OpLabel\n"                       // (in % 2) != 0
759         "OpStore %26 %15\n"               // write 2
760         "OpBranch %29\n"
761         "%29 = OpLabel\n"
762         // End of branch logic
763         "OpReturn\n"
764         "OpFunctionEnd\n";
765 	// clang-format on
766 
767 	test(
768 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
769 }
770 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalReturnTrue)771 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalReturnTrue)
772 {
773 	std::stringstream src;
774 	// clang-format off
775     src <<
776         "OpCapability Shader\n"
777         "OpMemoryModel Logical GLSL450\n"
778         "OpEntryPoint GLCompute %1 \"main\" %2\n"
779         "OpExecutionMode %1 LocalSize " <<
780         GetParam().localSizeX << " " <<
781         GetParam().localSizeY << " " <<
782         GetParam().localSizeZ << "\n" <<
783         "OpDecorate %3 ArrayStride 4\n"
784         "OpMemberDecorate %4 0 Offset 0\n"
785         "OpDecorate %4 BufferBlock\n"
786         "OpDecorate %5 DescriptorSet 0\n"
787         "OpDecorate %5 Binding 1\n"
788         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
789         "OpDecorate %6 DescriptorSet 0\n"
790         "OpDecorate %6 Binding 0\n"
791         "%7 = OpTypeVoid\n"
792         "%8 = OpTypeFunction %7\n"             // void()
793         "%9 = OpTypeInt 32 1\n"                // int32
794         "%10 = OpTypeInt 32 0\n"                // uint32
795         "%11 = OpTypeBool\n"
796         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
797         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
798         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
799         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
800         "%13 = OpConstant %9 0\n"               // int32(0)
801         "%14 = OpConstant %9 1\n"               // int32(1)
802         "%15 = OpConstant %9 2\n"               // int32(2)
803         "%16 = OpConstant %10 0\n"              // uint32(0)
804         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
805         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
806         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
807         "%19 = OpTypePointer Input %10\n"       // uint32*
808         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
809         "%20 = OpTypePointer Uniform %9\n"      // int32*
810         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
811         "%21 = OpLabel\n"
812         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
813         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
814         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
815         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
816         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
817                                                 // Start of branch logic
818                                                 // %25 = in value
819         "%27 = OpSMod %9 %25 %15\n"             // in % 2
820         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
821         "OpSelectionMerge %29 None\n"
822         "OpBranchConditional %28 %30 %29\n"
823         "%30 = OpLabel\n"                       // (in % 2) == 0
824         "OpReturn\n"
825         "%29 = OpLabel\n"                       // merge
826         "OpStore %26 %15\n"               // write 2
827                                           // End of branch logic
828         "OpReturn\n"
829         "OpFunctionEnd\n";
830 	// clang-format on
831 
832 	test(
833 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 0 : 2; });
834 }
835 
836 // TODO: Test for parallel assignment
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalPhi)837 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalPhi)
838 {
839 	std::stringstream src;
840 	// clang-format off
841     src <<
842         "OpCapability Shader\n"
843         "OpMemoryModel Logical GLSL450\n"
844         "OpEntryPoint GLCompute %1 \"main\" %2\n"
845         "OpExecutionMode %1 LocalSize " <<
846         GetParam().localSizeX << " " <<
847         GetParam().localSizeY << " " <<
848         GetParam().localSizeZ << "\n" <<
849         "OpDecorate %3 ArrayStride 4\n"
850         "OpMemberDecorate %4 0 Offset 0\n"
851         "OpDecorate %4 BufferBlock\n"
852         "OpDecorate %5 DescriptorSet 0\n"
853         "OpDecorate %5 Binding 1\n"
854         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
855         "OpDecorate %6 DescriptorSet 0\n"
856         "OpDecorate %6 Binding 0\n"
857         "%7 = OpTypeVoid\n"
858         "%8 = OpTypeFunction %7\n"             // void()
859         "%9 = OpTypeInt 32 1\n"                // int32
860         "%10 = OpTypeInt 32 0\n"                // uint32
861         "%11 = OpTypeBool\n"
862         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
863         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
864         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
865         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
866         "%13 = OpConstant %9 0\n"               // int32(0)
867         "%14 = OpConstant %9 1\n"               // int32(1)
868         "%15 = OpConstant %9 2\n"               // int32(2)
869         "%16 = OpConstant %10 0\n"              // uint32(0)
870         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
871         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
872         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
873         "%19 = OpTypePointer Input %10\n"       // uint32*
874         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
875         "%20 = OpTypePointer Uniform %9\n"      // int32*
876         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
877         "%21 = OpLabel\n"
878         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
879         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
880         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
881         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
882         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
883                                                 // Start of branch logic
884                                                 // %25 = in value
885         "%27 = OpSMod %9 %25 %15\n"             // in % 2
886         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
887         "OpSelectionMerge %29 None\n"
888         "OpBranchConditional %28 %30 %31\n"
889         "%30 = OpLabel\n"                       // (in % 2) == 0
890         "OpBranch %29\n"
891         "%31 = OpLabel\n"                       // (in % 2) != 0
892         "OpBranch %29\n"
893         "%29 = OpLabel\n"
894         "%32 = OpPhi %9 %14 %30 %15 %31\n"      // (in % 2) == 0 ? 1 : 2
895                                                 // End of branch logic
896         "OpStore %26 %32\n"
897         "OpReturn\n"
898         "OpFunctionEnd\n";
899 	// clang-format on
900 
901 	test(
902 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
903 }
904 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchEmptyCases)905 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)
906 {
907 	std::stringstream src;
908 	// clang-format off
909     src <<
910         "OpCapability Shader\n"
911         "OpMemoryModel Logical GLSL450\n"
912         "OpEntryPoint GLCompute %1 \"main\" %2\n"
913         "OpExecutionMode %1 LocalSize " <<
914         GetParam().localSizeX << " " <<
915         GetParam().localSizeY << " " <<
916         GetParam().localSizeZ << "\n" <<
917         "OpDecorate %3 ArrayStride 4\n"
918         "OpMemberDecorate %4 0 Offset 0\n"
919         "OpDecorate %4 BufferBlock\n"
920         "OpDecorate %5 DescriptorSet 0\n"
921         "OpDecorate %5 Binding 1\n"
922         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
923         "OpDecorate %6 DescriptorSet 0\n"
924         "OpDecorate %6 Binding 0\n"
925         "%7 = OpTypeVoid\n"
926         "%8 = OpTypeFunction %7\n"             // void()
927         "%9 = OpTypeInt 32 1\n"                // int32
928         "%10 = OpTypeInt 32 0\n"                // uint32
929         "%11 = OpTypeBool\n"
930         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
931         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
932         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
933         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
934         "%13 = OpConstant %9 0\n"               // int32(0)
935         "%14 = OpConstant %9 2\n"               // int32(2)
936         "%15 = OpConstant %10 0\n"              // uint32(0)
937         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
938         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
939         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
940         "%18 = OpTypePointer Input %10\n"       // uint32*
941         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
942         "%19 = OpTypePointer Uniform %9\n"      // int32*
943         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
944         "%20 = OpLabel\n"
945         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
946         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
947         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
948         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
949         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
950                                                 // Start of branch logic
951                                                 // %24 = in value
952         "%26 = OpSMod %9 %24 %14\n"             // in % 2
953         "OpSelectionMerge %27 None\n"
954         "OpSwitch %26 %27 0 %28 1 %29\n"
955         "%28 = OpLabel\n"                       // (in % 2) == 0
956         "OpBranch %27\n"
957         "%29 = OpLabel\n"                       // (in % 2) == 1
958         "OpBranch %27\n"
959         "%27 = OpLabel\n"
960         // %26 = out value
961         // End of branch logic
962         "OpStore %25 %26\n"               // use SSA value from previous block
963         "OpReturn\n"
964         "OpFunctionEnd\n";
965 	// clang-format on
966 
967 	test(
968 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
969 }
970 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchStore)971 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)
972 {
973 	std::stringstream src;
974 	// clang-format off
975     src <<
976         "OpCapability Shader\n"
977         "OpMemoryModel Logical GLSL450\n"
978         "OpEntryPoint GLCompute %1 \"main\" %2\n"
979         "OpExecutionMode %1 LocalSize " <<
980         GetParam().localSizeX << " " <<
981         GetParam().localSizeY << " " <<
982         GetParam().localSizeZ << "\n" <<
983         "OpDecorate %3 ArrayStride 4\n"
984         "OpMemberDecorate %4 0 Offset 0\n"
985         "OpDecorate %4 BufferBlock\n"
986         "OpDecorate %5 DescriptorSet 0\n"
987         "OpDecorate %5 Binding 1\n"
988         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
989         "OpDecorate %6 DescriptorSet 0\n"
990         "OpDecorate %6 Binding 0\n"
991         "%7 = OpTypeVoid\n"
992         "%8 = OpTypeFunction %7\n"             // void()
993         "%9 = OpTypeInt 32 1\n"                // int32
994         "%10 = OpTypeInt 32 0\n"                // uint32
995         "%11 = OpTypeBool\n"
996         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
997         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
998         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
999         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1000         "%13 = OpConstant %9 0\n"               // int32(0)
1001         "%14 = OpConstant %9 1\n"               // int32(1)
1002         "%15 = OpConstant %9 2\n"               // int32(2)
1003         "%16 = OpConstant %10 0\n"              // uint32(0)
1004         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1005         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1006         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1007         "%19 = OpTypePointer Input %10\n"       // uint32*
1008         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1009         "%20 = OpTypePointer Uniform %9\n"      // int32*
1010         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1011         "%21 = OpLabel\n"
1012         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1013         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1014         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1015         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1016         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1017                                                 // Start of branch logic
1018                                                 // %25 = in value
1019         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1020         "OpSelectionMerge %28 None\n"
1021         "OpSwitch %27 %28 0 %29 1 %30\n"
1022         "%29 = OpLabel\n"                       // (in % 2) == 0
1023         "OpStore %26 %15\n"               // write 2
1024         "OpBranch %28\n"
1025         "%30 = OpLabel\n"                       // (in % 2) == 1
1026         "OpStore %26 %14\n"               // write 1
1027         "OpBranch %28\n"
1028         "%28 = OpLabel\n"
1029         // End of branch logic
1030         "OpReturn\n"
1031         "OpFunctionEnd\n";
1032 	// clang-format on
1033 
1034 	test(
1035 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });
1036 }
1037 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchCaseReturn)1038 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)
1039 {
1040 	std::stringstream src;
1041 	// clang-format off
1042     src <<
1043         "OpCapability Shader\n"
1044         "OpMemoryModel Logical GLSL450\n"
1045         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1046         "OpExecutionMode %1 LocalSize " <<
1047         GetParam().localSizeX << " " <<
1048         GetParam().localSizeY << " " <<
1049         GetParam().localSizeZ << "\n" <<
1050         "OpDecorate %3 ArrayStride 4\n"
1051         "OpMemberDecorate %4 0 Offset 0\n"
1052         "OpDecorate %4 BufferBlock\n"
1053         "OpDecorate %5 DescriptorSet 0\n"
1054         "OpDecorate %5 Binding 1\n"
1055         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1056         "OpDecorate %6 DescriptorSet 0\n"
1057         "OpDecorate %6 Binding 0\n"
1058         "%7 = OpTypeVoid\n"
1059         "%8 = OpTypeFunction %7\n"             // void()
1060         "%9 = OpTypeInt 32 1\n"                // int32
1061         "%10 = OpTypeInt 32 0\n"                // uint32
1062         "%11 = OpTypeBool\n"
1063         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1064         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1065         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1066         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1067         "%13 = OpConstant %9 0\n"               // int32(0)
1068         "%14 = OpConstant %9 1\n"               // int32(1)
1069         "%15 = OpConstant %9 2\n"               // int32(2)
1070         "%16 = OpConstant %10 0\n"              // uint32(0)
1071         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1072         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1073         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1074         "%19 = OpTypePointer Input %10\n"       // uint32*
1075         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1076         "%20 = OpTypePointer Uniform %9\n"      // int32*
1077         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1078         "%21 = OpLabel\n"
1079         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1080         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1081         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1082         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1083         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1084                                                 // Start of branch logic
1085                                                 // %25 = in value
1086         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1087         "OpSelectionMerge %28 None\n"
1088         "OpSwitch %27 %28 0 %29 1 %30\n"
1089         "%29 = OpLabel\n"                       // (in % 2) == 0
1090         "OpBranch %28\n"
1091         "%30 = OpLabel\n"                       // (in % 2) == 1
1092         "OpReturn\n"
1093         "%28 = OpLabel\n"
1094         "OpStore %26 %14\n"               // write 1
1095                                           // End of branch logic
1096         "OpReturn\n"
1097         "OpFunctionEnd\n";
1098 	// clang-format on
1099 
1100 	test(
1101 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });
1102 }
1103 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchDefaultReturn)1104 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)
1105 {
1106 	std::stringstream src;
1107 	// clang-format off
1108     src <<
1109         "OpCapability Shader\n"
1110         "OpMemoryModel Logical GLSL450\n"
1111         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1112         "OpExecutionMode %1 LocalSize " <<
1113         GetParam().localSizeX << " " <<
1114         GetParam().localSizeY << " " <<
1115         GetParam().localSizeZ << "\n" <<
1116         "OpDecorate %3 ArrayStride 4\n"
1117         "OpMemberDecorate %4 0 Offset 0\n"
1118         "OpDecorate %4 BufferBlock\n"
1119         "OpDecorate %5 DescriptorSet 0\n"
1120         "OpDecorate %5 Binding 1\n"
1121         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1122         "OpDecorate %6 DescriptorSet 0\n"
1123         "OpDecorate %6 Binding 0\n"
1124         "%7 = OpTypeVoid\n"
1125         "%8 = OpTypeFunction %7\n"             // void()
1126         "%9 = OpTypeInt 32 1\n"                // int32
1127         "%10 = OpTypeInt 32 0\n"                // uint32
1128         "%11 = OpTypeBool\n"
1129         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1130         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1131         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1132         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1133         "%13 = OpConstant %9 0\n"               // int32(0)
1134         "%14 = OpConstant %9 1\n"               // int32(1)
1135         "%15 = OpConstant %9 2\n"               // int32(2)
1136         "%16 = OpConstant %10 0\n"              // uint32(0)
1137         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1138         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1139         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1140         "%19 = OpTypePointer Input %10\n"       // uint32*
1141         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1142         "%20 = OpTypePointer Uniform %9\n"      // int32*
1143         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1144         "%21 = OpLabel\n"
1145         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1146         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1147         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1148         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1149         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1150                                                 // Start of branch logic
1151                                                 // %25 = in value
1152         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1153         "OpSelectionMerge %28 None\n"
1154         "OpSwitch %27 %29 1 %30\n"
1155         "%30 = OpLabel\n"                       // (in % 2) == 1
1156         "OpBranch %28\n"
1157         "%29 = OpLabel\n"                       // (in % 2) != 1
1158         "OpReturn\n"
1159         "%28 = OpLabel\n"                       // merge
1160         "OpStore %26 %14\n"               // write 1
1161                                           // End of branch logic
1162         "OpReturn\n"
1163         "OpFunctionEnd\n";
1164 	// clang-format on
1165 
1166 	test(
1167 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });
1168 }
1169 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchCaseFallthrough)1170 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)
1171 {
1172 	std::stringstream src;
1173 	// clang-format off
1174     src <<
1175         "OpCapability Shader\n"
1176         "OpMemoryModel Logical GLSL450\n"
1177         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1178         "OpExecutionMode %1 LocalSize " <<
1179         GetParam().localSizeX << " " <<
1180         GetParam().localSizeY << " " <<
1181         GetParam().localSizeZ << "\n" <<
1182         "OpDecorate %3 ArrayStride 4\n"
1183         "OpMemberDecorate %4 0 Offset 0\n"
1184         "OpDecorate %4 BufferBlock\n"
1185         "OpDecorate %5 DescriptorSet 0\n"
1186         "OpDecorate %5 Binding 1\n"
1187         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1188         "OpDecorate %6 DescriptorSet 0\n"
1189         "OpDecorate %6 Binding 0\n"
1190         "%7 = OpTypeVoid\n"
1191         "%8 = OpTypeFunction %7\n"             // void()
1192         "%9 = OpTypeInt 32 1\n"                // int32
1193         "%10 = OpTypeInt 32 0\n"                // uint32
1194         "%11 = OpTypeBool\n"
1195         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1196         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1197         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1198         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1199         "%13 = OpConstant %9 0\n"               // int32(0)
1200         "%14 = OpConstant %9 1\n"               // int32(1)
1201         "%15 = OpConstant %9 2\n"               // int32(2)
1202         "%16 = OpConstant %10 0\n"              // uint32(0)
1203         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1204         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1205         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1206         "%19 = OpTypePointer Input %10\n"       // uint32*
1207         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1208         "%20 = OpTypePointer Uniform %9\n"      // int32*
1209         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1210         "%21 = OpLabel\n"
1211         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1212         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1213         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1214         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1215         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1216                                                 // Start of branch logic
1217                                                 // %25 = in value
1218         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1219         "OpSelectionMerge %28 None\n"
1220         "OpSwitch %27 %29 0 %30 1 %31\n"
1221         "%30 = OpLabel\n"                       // (in % 2) == 0
1222         "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1223         "OpStore %26 %32\n"               // write a value (overwritten later)
1224         "OpBranch %31\n"                  // fallthrough
1225         "%31 = OpLabel\n"                       // (in % 2) == 1
1226         "OpStore %26 %15\n"               // write 2
1227         "OpBranch %28\n"
1228         "%29 = OpLabel\n"                       // unreachable
1229         "OpUnreachable\n"
1230         "%28 = OpLabel\n"                       // merge
1231                                                 // End of branch logic
1232         "OpReturn\n"
1233         "OpFunctionEnd\n";
1234 	// clang-format on
1235 
1236 	test(
1237 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
1238 }
1239 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchDefaultFallthrough)1240 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)
1241 {
1242 	std::stringstream src;
1243 	// clang-format off
1244     src <<
1245         "OpCapability Shader\n"
1246         "OpMemoryModel Logical GLSL450\n"
1247         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1248         "OpExecutionMode %1 LocalSize " <<
1249         GetParam().localSizeX << " " <<
1250         GetParam().localSizeY << " " <<
1251         GetParam().localSizeZ << "\n" <<
1252         "OpDecorate %3 ArrayStride 4\n"
1253         "OpMemberDecorate %4 0 Offset 0\n"
1254         "OpDecorate %4 BufferBlock\n"
1255         "OpDecorate %5 DescriptorSet 0\n"
1256         "OpDecorate %5 Binding 1\n"
1257         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1258         "OpDecorate %6 DescriptorSet 0\n"
1259         "OpDecorate %6 Binding 0\n"
1260         "%7 = OpTypeVoid\n"
1261         "%8 = OpTypeFunction %7\n"             // void()
1262         "%9 = OpTypeInt 32 1\n"                // int32
1263         "%10 = OpTypeInt 32 0\n"                // uint32
1264         "%11 = OpTypeBool\n"
1265         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1266         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1267         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1268         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1269         "%13 = OpConstant %9 0\n"               // int32(0)
1270         "%14 = OpConstant %9 1\n"               // int32(1)
1271         "%15 = OpConstant %9 2\n"               // int32(2)
1272         "%16 = OpConstant %10 0\n"              // uint32(0)
1273         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1274         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1275         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1276         "%19 = OpTypePointer Input %10\n"       // uint32*
1277         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1278         "%20 = OpTypePointer Uniform %9\n"      // int32*
1279         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1280         "%21 = OpLabel\n"
1281         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1282         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1283         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1284         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1285         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1286                                                 // Start of branch logic
1287                                                 // %25 = in value
1288         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1289         "OpSelectionMerge %28 None\n"
1290         "OpSwitch %27 %29 0 %30 1 %31\n"
1291         "%30 = OpLabel\n"                       // (in % 2) == 0
1292         "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1293         "OpStore %26 %32\n"               // write a value (overwritten later)
1294         "OpBranch %29\n"                  // fallthrough
1295         "%29 = OpLabel\n"                       // default
1296         "%33 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1297         "OpStore %26 %33\n"               // write a value (overwritten later)
1298         "OpBranch %31\n"                  // fallthrough
1299         "%31 = OpLabel\n"                       // (in % 2) == 1
1300         "OpStore %26 %15\n"               // write 2
1301         "OpBranch %28\n"
1302         "%28 = OpLabel\n"                       // merge
1303                                                 // End of branch logic
1304         "OpReturn\n"
1305         "OpFunctionEnd\n";
1306 	// clang-format on
1307 
1308 	test(
1309 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
1310 }
1311 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchPhi)1312 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)
1313 {
1314 	std::stringstream src;
1315 	// clang-format off
1316     src <<
1317         "OpCapability Shader\n"
1318         "OpMemoryModel Logical GLSL450\n"
1319         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1320         "OpExecutionMode %1 LocalSize " <<
1321         GetParam().localSizeX << " " <<
1322         GetParam().localSizeY << " " <<
1323         GetParam().localSizeZ << "\n" <<
1324         "OpDecorate %3 ArrayStride 4\n"
1325         "OpMemberDecorate %4 0 Offset 0\n"
1326         "OpDecorate %4 BufferBlock\n"
1327         "OpDecorate %5 DescriptorSet 0\n"
1328         "OpDecorate %5 Binding 1\n"
1329         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1330         "OpDecorate %6 DescriptorSet 0\n"
1331         "OpDecorate %6 Binding 0\n"
1332         "%7 = OpTypeVoid\n"
1333         "%8 = OpTypeFunction %7\n"             // void()
1334         "%9 = OpTypeInt 32 1\n"                // int32
1335         "%10 = OpTypeInt 32 0\n"                // uint32
1336         "%11 = OpTypeBool\n"
1337         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1338         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1339         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1340         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1341         "%13 = OpConstant %9 0\n"               // int32(0)
1342         "%14 = OpConstant %9 1\n"               // int32(1)
1343         "%15 = OpConstant %9 2\n"               // int32(2)
1344         "%16 = OpConstant %10 0\n"              // uint32(0)
1345         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1346         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1347         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1348         "%19 = OpTypePointer Input %10\n"       // uint32*
1349         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1350         "%20 = OpTypePointer Uniform %9\n"      // int32*
1351         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1352         "%21 = OpLabel\n"
1353         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1354         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1355         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1356         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1357         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1358                                                 // Start of branch logic
1359                                                 // %25 = in value
1360         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1361         "OpSelectionMerge %28 None\n"
1362         "OpSwitch %27 %29 1 %30\n"
1363         "%30 = OpLabel\n"                       // (in % 2) == 1
1364         "OpBranch %28\n"
1365         "%29 = OpLabel\n"                       // (in % 2) != 1
1366         "OpBranch %28\n"
1367         "%28 = OpLabel\n"                       // merge
1368         "%31 = OpPhi %9 %14 %30 %15 %29\n"      // (in % 2) == 1 ? 1 : 2
1369         "OpStore %26 %31\n"
1370         // End of branch logic
1371         "OpReturn\n"
1372         "OpFunctionEnd\n";
1373 	// clang-format on
1374 
1375 	test(
1376 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
1377 }
1378 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,LoopDivergentMergePhi)1379 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)
1380 {
1381 	// #version 450
1382 	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1383 	// layout(binding = 0, std430) buffer InBuffer
1384 	// {
1385 	//     int Data[];
1386 	// } In;
1387 	// layout(binding = 1, std430) buffer OutBuffer
1388 	// {
1389 	//     int Data[];
1390 	// } Out;
1391 	// void main()
1392 	// {
1393 	//     int phi = 0;
1394 	//     uint lane = gl_GlobalInvocationID.x % 4;
1395 	//     for (uint i = 0; i < 4; i++)
1396 	//     {
1397 	//         if (lane == i)
1398 	//         {
1399 	//             phi = In.Data[gl_GlobalInvocationID.x];
1400 	//             break;
1401 	//         }
1402 	//     }
1403 	//     Out.Data[gl_GlobalInvocationID.x] = phi;
1404 	// }
1405 	std::stringstream src;
1406 	// clang-format off
1407     src <<
1408         "OpCapability Shader\n"
1409         "%1 = OpExtInstImport \"GLSL.std.450\"\n"
1410         "OpMemoryModel Logical GLSL450\n"
1411         "OpEntryPoint GLCompute %2 \"main\" %3\n"
1412         "OpExecutionMode %2 LocalSize " <<
1413         GetParam().localSizeX << " " <<
1414         GetParam().localSizeY << " " <<
1415         GetParam().localSizeZ << "\n" <<
1416         "OpDecorate %3 BuiltIn GlobalInvocationId\n"
1417         "OpDecorate %4 ArrayStride 4\n"
1418         "OpMemberDecorate %5 0 Offset 0\n"
1419         "OpDecorate %5 BufferBlock\n"
1420         "OpDecorate %6 DescriptorSet 0\n"
1421         "OpDecorate %6 Binding 0\n"
1422         "OpDecorate %7 ArrayStride 4\n"
1423         "OpMemberDecorate %8 0 Offset 0\n"
1424         "OpDecorate %8 BufferBlock\n"
1425         "OpDecorate %9 DescriptorSet 0\n"
1426         "OpDecorate %9 Binding 1\n"
1427         "%10 = OpTypeVoid\n"
1428         "%11 = OpTypeFunction %10\n"
1429         "%12 = OpTypeInt 32 1\n"
1430         "%13 = OpConstant %12 0\n"
1431         "%14 = OpTypeInt 32 0\n"
1432         "%15 = OpTypeVector %14 3\n"
1433         "%16 = OpTypePointer Input %15\n"
1434         "%3 = OpVariable %16 Input\n"
1435         "%17 = OpConstant %14 0\n"
1436         "%18 = OpTypePointer Input %14\n"
1437         "%19 = OpConstant %14 4\n"
1438         "%20 = OpTypeBool\n"
1439         "%4 = OpTypeRuntimeArray %12\n"
1440         "%5 = OpTypeStruct %4\n"
1441         "%21 = OpTypePointer Uniform %5\n"
1442         "%6 = OpVariable %21 Uniform\n"
1443         "%22 = OpTypePointer Uniform %12\n"
1444         "%23 = OpConstant %12 1\n"
1445         "%7 = OpTypeRuntimeArray %12\n"
1446         "%8 = OpTypeStruct %7\n"
1447         "%24 = OpTypePointer Uniform %8\n"
1448         "%9 = OpVariable %24 Uniform\n"
1449         "%2 = OpFunction %10 None %11\n"
1450         "%25 = OpLabel\n"
1451         "%26 = OpAccessChain %18 %3 %17\n"
1452         "%27 = OpLoad %14 %26\n"
1453         "%28 = OpUMod %14 %27 %19\n"
1454         "OpBranch %29\n"
1455         "%29 = OpLabel\n"
1456         "%30 = OpPhi %14 %17 %25 %31 %32\n"
1457         "%33 = OpULessThan %20 %30 %19\n"
1458         "OpLoopMerge %34 %32 None\n"
1459         "OpBranchConditional %33 %35 %34\n"
1460         "%35 = OpLabel\n"
1461         "%36 = OpIEqual %20 %28 %30\n"
1462         "OpSelectionMerge %37 None\n"
1463         "OpBranchConditional %36 %38 %37\n"
1464         "%38 = OpLabel\n"
1465         "%39 = OpAccessChain %22 %6 %13 %27\n"
1466         "%40 = OpLoad %12 %39\n"
1467         "OpBranch %34\n"
1468         "%37 = OpLabel\n"
1469         "OpBranch %32\n"
1470         "%32 = OpLabel\n"
1471         "%31 = OpIAdd %14 %30 %23\n"
1472         "OpBranch %29\n"
1473         "%34 = OpLabel\n"
1474         "%41 = OpPhi %12 %13 %29 %40 %38\n" // %40: phi
1475         "%42 = OpAccessChain %22 %9 %13 %27\n"
1476         "OpStore %42 %41\n"
1477         "OpReturn\n"
1478         "OpFunctionEnd\n";
1479 	// clang-format on
1480 
1481 	test(
1482 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
1483 }
1484