xref: /aosp_15_r20/external/swiftshader/src/Pipeline/ComputeProgram.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14*03ce13f7SAndroid Build Coastguard Worker 
15*03ce13f7SAndroid Build Coastguard Worker #include "ComputeProgram.hpp"
16*03ce13f7SAndroid Build Coastguard Worker 
17*03ce13f7SAndroid Build Coastguard Worker #include "Constants.hpp"
18*03ce13f7SAndroid Build Coastguard Worker #include "System/Debug.hpp"
19*03ce13f7SAndroid Build Coastguard Worker #include "Vulkan/VkDevice.hpp"
20*03ce13f7SAndroid Build Coastguard Worker #include "Vulkan/VkPipelineLayout.hpp"
21*03ce13f7SAndroid Build Coastguard Worker 
22*03ce13f7SAndroid Build Coastguard Worker #include "marl/defer.h"
23*03ce13f7SAndroid Build Coastguard Worker #include "marl/trace.h"
24*03ce13f7SAndroid Build Coastguard Worker #include "marl/waitgroup.h"
25*03ce13f7SAndroid Build Coastguard Worker 
26*03ce13f7SAndroid Build Coastguard Worker #include <queue>
27*03ce13f7SAndroid Build Coastguard Worker 
28*03ce13f7SAndroid Build Coastguard Worker namespace sw {
29*03ce13f7SAndroid Build Coastguard Worker 
ComputeProgram(vk::Device * device,std::shared_ptr<SpirvShader> shader,const vk::PipelineLayout * pipelineLayout,const vk::DescriptorSet::Bindings & descriptorSets)30*03ce13f7SAndroid Build Coastguard Worker ComputeProgram::ComputeProgram(vk::Device *device, std::shared_ptr<SpirvShader> shader, const vk::PipelineLayout *pipelineLayout, const vk::DescriptorSet::Bindings &descriptorSets)
31*03ce13f7SAndroid Build Coastguard Worker     : device(device)
32*03ce13f7SAndroid Build Coastguard Worker     , shader(shader)
33*03ce13f7SAndroid Build Coastguard Worker     , pipelineLayout(pipelineLayout)
34*03ce13f7SAndroid Build Coastguard Worker     , descriptorSets(descriptorSets)
35*03ce13f7SAndroid Build Coastguard Worker {
36*03ce13f7SAndroid Build Coastguard Worker }
37*03ce13f7SAndroid Build Coastguard Worker 
~ComputeProgram()38*03ce13f7SAndroid Build Coastguard Worker ComputeProgram::~ComputeProgram()
39*03ce13f7SAndroid Build Coastguard Worker {
40*03ce13f7SAndroid Build Coastguard Worker }
41*03ce13f7SAndroid Build Coastguard Worker 
generate()42*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::generate()
43*03ce13f7SAndroid Build Coastguard Worker {
44*03ce13f7SAndroid Build Coastguard Worker 	MARL_SCOPED_EVENT("ComputeProgram::generate");
45*03ce13f7SAndroid Build Coastguard Worker 
46*03ce13f7SAndroid Build Coastguard Worker 	SpirvRoutine routine(pipelineLayout);
47*03ce13f7SAndroid Build Coastguard Worker 	shader->emitProlog(&routine);
48*03ce13f7SAndroid Build Coastguard Worker 	emit(&routine);
49*03ce13f7SAndroid Build Coastguard Worker 	shader->emitEpilog(&routine);
50*03ce13f7SAndroid Build Coastguard Worker }
51*03ce13f7SAndroid Build Coastguard Worker 
setWorkgroupBuiltins(Pointer<Byte> data,SpirvRoutine * routine,Int workgroupID[3])52*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::setWorkgroupBuiltins(Pointer<Byte> data, SpirvRoutine *routine, Int workgroupID[3])
53*03ce13f7SAndroid Build Coastguard Worker {
54*03ce13f7SAndroid Build Coastguard Worker 	// TODO(b/146486064): Consider only assigning these to the SpirvRoutine iff they are ever going to be read.
55*03ce13f7SAndroid Build Coastguard Worker 	routine->numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
56*03ce13f7SAndroid Build Coastguard Worker 	routine->workgroupID = Insert(Insert(Insert(Int4(0), workgroupID[0], 0), workgroupID[1], 1), workgroupID[2], 2);
57*03ce13f7SAndroid Build Coastguard Worker 	routine->workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
58*03ce13f7SAndroid Build Coastguard Worker 	routine->subgroupsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, subgroupsPerWorkgroup));
59*03ce13f7SAndroid Build Coastguard Worker 	routine->invocationsPerSubgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerSubgroup));
60*03ce13f7SAndroid Build Coastguard Worker 
61*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInNumWorkgroups, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
62*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.x));
63*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.y));
64*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.z));
65*03ce13f7SAndroid Build Coastguard Worker 	});
66*03ce13f7SAndroid Build Coastguard Worker 
67*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInWorkgroupId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
68*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(workgroupID[0]));
69*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(workgroupID[1]));
70*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(workgroupID[2]));
71*03ce13f7SAndroid Build Coastguard Worker 	});
72*03ce13f7SAndroid Build Coastguard Worker 
73*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInWorkgroupSize, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
74*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.x));
75*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.y));
76*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.z));
77*03ce13f7SAndroid Build Coastguard Worker 	});
78*03ce13f7SAndroid Build Coastguard Worker 
79*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInNumSubgroups, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
80*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(routine->subgroupsPerWorkgroup));
81*03ce13f7SAndroid Build Coastguard Worker 	});
82*03ce13f7SAndroid Build Coastguard Worker 
83*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInSubgroupSize, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
84*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(routine->invocationsPerSubgroup));
85*03ce13f7SAndroid Build Coastguard Worker 	});
86*03ce13f7SAndroid Build Coastguard Worker 
87*03ce13f7SAndroid Build Coastguard Worker 	routine->setImmutableInputBuiltins(shader.get());
88*03ce13f7SAndroid Build Coastguard Worker }
89*03ce13f7SAndroid Build Coastguard Worker 
setSubgroupBuiltins(Pointer<Byte> data,SpirvRoutine * routine,Int workgroupID[3],SIMD::Int localInvocationIndex,Int subgroupIndex)90*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::setSubgroupBuiltins(Pointer<Byte> data, SpirvRoutine *routine, Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex)
91*03ce13f7SAndroid Build Coastguard Worker {
92*03ce13f7SAndroid Build Coastguard Worker 	Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
93*03ce13f7SAndroid Build Coastguard Worker 	Int4 workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
94*03ce13f7SAndroid Build Coastguard Worker 
95*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupSizeX = workgroupSize.x;
96*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupSizeY = workgroupSize.y;
97*03ce13f7SAndroid Build Coastguard Worker 
98*03ce13f7SAndroid Build Coastguard Worker 	SIMD::Int localInvocationID[3];
99*03ce13f7SAndroid Build Coastguard Worker 	{
100*03ce13f7SAndroid Build Coastguard Worker 		SIMD::Int idx = localInvocationIndex;
101*03ce13f7SAndroid Build Coastguard Worker 		localInvocationID[2] = idx / SIMD::Int(workgroupSizeX * workgroupSizeY);
102*03ce13f7SAndroid Build Coastguard Worker 		idx -= localInvocationID[2] * SIMD::Int(workgroupSizeX * workgroupSizeY);  // modulo
103*03ce13f7SAndroid Build Coastguard Worker 		localInvocationID[1] = idx / SIMD::Int(workgroupSizeX);
104*03ce13f7SAndroid Build Coastguard Worker 		idx -= localInvocationID[1] * SIMD::Int(workgroupSizeX);  // modulo
105*03ce13f7SAndroid Build Coastguard Worker 		localInvocationID[0] = idx;
106*03ce13f7SAndroid Build Coastguard Worker 	}
107*03ce13f7SAndroid Build Coastguard Worker 
108*03ce13f7SAndroid Build Coastguard Worker 	Int4 wgID = Insert(Insert(Insert(Int4(0), workgroupID[0], 0), workgroupID[1], 1), workgroupID[2], 2);
109*03ce13f7SAndroid Build Coastguard Worker 	auto localBase = workgroupSize * wgID;
110*03ce13f7SAndroid Build Coastguard Worker 	SIMD::Int globalInvocationID[3];
111*03ce13f7SAndroid Build Coastguard Worker 	globalInvocationID[0] = SIMD::Int(Extract(localBase, 0)) + localInvocationID[0];
112*03ce13f7SAndroid Build Coastguard Worker 	globalInvocationID[1] = SIMD::Int(Extract(localBase, 1)) + localInvocationID[1];
113*03ce13f7SAndroid Build Coastguard Worker 	globalInvocationID[2] = SIMD::Int(Extract(localBase, 2)) + localInvocationID[2];
114*03ce13f7SAndroid Build Coastguard Worker 
115*03ce13f7SAndroid Build Coastguard Worker 	routine->localInvocationIndex = localInvocationIndex;
116*03ce13f7SAndroid Build Coastguard Worker 	routine->subgroupIndex = subgroupIndex;
117*03ce13f7SAndroid Build Coastguard Worker 	routine->localInvocationID[0] = localInvocationID[0];
118*03ce13f7SAndroid Build Coastguard Worker 	routine->localInvocationID[1] = localInvocationID[1];
119*03ce13f7SAndroid Build Coastguard Worker 	routine->localInvocationID[2] = localInvocationID[2];
120*03ce13f7SAndroid Build Coastguard Worker 	routine->globalInvocationID[0] = globalInvocationID[0];
121*03ce13f7SAndroid Build Coastguard Worker 	routine->globalInvocationID[1] = globalInvocationID[1];
122*03ce13f7SAndroid Build Coastguard Worker 	routine->globalInvocationID[2] = globalInvocationID[2];
123*03ce13f7SAndroid Build Coastguard Worker 
124*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInLocalInvocationIndex, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
125*03ce13f7SAndroid Build Coastguard Worker 		ASSERT(builtin.SizeInComponents == 1);
126*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent] = As<SIMD::Float>(localInvocationIndex);
127*03ce13f7SAndroid Build Coastguard Worker 	});
128*03ce13f7SAndroid Build Coastguard Worker 
129*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInSubgroupId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
130*03ce13f7SAndroid Build Coastguard Worker 		ASSERT(builtin.SizeInComponents == 1);
131*03ce13f7SAndroid Build Coastguard Worker 		value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupIndex));
132*03ce13f7SAndroid Build Coastguard Worker 	});
133*03ce13f7SAndroid Build Coastguard Worker 
134*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInLocalInvocationId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
135*03ce13f7SAndroid Build Coastguard Worker 		for(uint32_t component = 0; component < builtin.SizeInComponents; component++)
136*03ce13f7SAndroid Build Coastguard Worker 		{
137*03ce13f7SAndroid Build Coastguard Worker 			value[builtin.FirstComponent + component] =
138*03ce13f7SAndroid Build Coastguard Worker 			    As<SIMD::Float>(localInvocationID[component]);
139*03ce13f7SAndroid Build Coastguard Worker 		}
140*03ce13f7SAndroid Build Coastguard Worker 	});
141*03ce13f7SAndroid Build Coastguard Worker 
142*03ce13f7SAndroid Build Coastguard Worker 	routine->setInputBuiltin(shader.get(), spv::BuiltInGlobalInvocationId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
143*03ce13f7SAndroid Build Coastguard Worker 		for(uint32_t component = 0; component < builtin.SizeInComponents; component++)
144*03ce13f7SAndroid Build Coastguard Worker 		{
145*03ce13f7SAndroid Build Coastguard Worker 			value[builtin.FirstComponent + component] =
146*03ce13f7SAndroid Build Coastguard Worker 			    As<SIMD::Float>(globalInvocationID[component]);
147*03ce13f7SAndroid Build Coastguard Worker 		}
148*03ce13f7SAndroid Build Coastguard Worker 	});
149*03ce13f7SAndroid Build Coastguard Worker }
150*03ce13f7SAndroid Build Coastguard Worker 
emit(SpirvRoutine * routine)151*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::emit(SpirvRoutine *routine)
152*03ce13f7SAndroid Build Coastguard Worker {
153*03ce13f7SAndroid Build Coastguard Worker 	Pointer<Byte> device = Arg<0>();
154*03ce13f7SAndroid Build Coastguard Worker 	Pointer<Byte> data = Arg<1>();
155*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupX = Arg<2>();
156*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupY = Arg<3>();
157*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupZ = Arg<4>();
158*03ce13f7SAndroid Build Coastguard Worker 	Pointer<Byte> workgroupMemory = Arg<5>();
159*03ce13f7SAndroid Build Coastguard Worker 	Int firstSubgroup = Arg<6>();
160*03ce13f7SAndroid Build Coastguard Worker 	Int subgroupCount = Arg<7>();
161*03ce13f7SAndroid Build Coastguard Worker 
162*03ce13f7SAndroid Build Coastguard Worker 	routine->device = device;
163*03ce13f7SAndroid Build Coastguard Worker 	routine->descriptorSets = data + OFFSET(Data, descriptorSets);
164*03ce13f7SAndroid Build Coastguard Worker 	routine->descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
165*03ce13f7SAndroid Build Coastguard Worker 	routine->pushConstants = data + OFFSET(Data, pushConstants);
166*03ce13f7SAndroid Build Coastguard Worker 	routine->constants = device + OFFSET(vk::Device, constants);
167*03ce13f7SAndroid Build Coastguard Worker 	routine->workgroupMemory = workgroupMemory;
168*03ce13f7SAndroid Build Coastguard Worker 
169*03ce13f7SAndroid Build Coastguard Worker 	Int invocationsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerWorkgroup));
170*03ce13f7SAndroid Build Coastguard Worker 
171*03ce13f7SAndroid Build Coastguard Worker 	Int workgroupID[3] = { workgroupX, workgroupY, workgroupZ };
172*03ce13f7SAndroid Build Coastguard Worker 	setWorkgroupBuiltins(data, routine, workgroupID);
173*03ce13f7SAndroid Build Coastguard Worker 
174*03ce13f7SAndroid Build Coastguard Worker 	For(Int i = 0, i < subgroupCount, i++)
175*03ce13f7SAndroid Build Coastguard Worker 	{
176*03ce13f7SAndroid Build Coastguard Worker 		auto subgroupIndex = firstSubgroup + i;
177*03ce13f7SAndroid Build Coastguard Worker 
178*03ce13f7SAndroid Build Coastguard Worker 		// TODO: Replace SIMD::Int(0, 1, 2, 3) with SIMD-width equivalent
179*03ce13f7SAndroid Build Coastguard Worker 		auto localInvocationIndex = SIMD::Int(subgroupIndex * SIMD::Width) + SIMD::Int(0, 1, 2, 3);
180*03ce13f7SAndroid Build Coastguard Worker 
181*03ce13f7SAndroid Build Coastguard Worker 		// Disable lanes where (invocationIDs >= invocationsPerWorkgroup)
182*03ce13f7SAndroid Build Coastguard Worker 		auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(invocationsPerWorkgroup));
183*03ce13f7SAndroid Build Coastguard Worker 
184*03ce13f7SAndroid Build Coastguard Worker 		setSubgroupBuiltins(data, routine, workgroupID, localInvocationIndex, subgroupIndex);
185*03ce13f7SAndroid Build Coastguard Worker 
186*03ce13f7SAndroid Build Coastguard Worker 		shader->emit(routine, activeLaneMask, activeLaneMask, descriptorSets);
187*03ce13f7SAndroid Build Coastguard Worker 	}
188*03ce13f7SAndroid Build Coastguard Worker }
189*03ce13f7SAndroid Build Coastguard Worker 
run(const vk::DescriptorSet::Array & descriptorSetObjects,const vk::DescriptorSet::Bindings & descriptorSets,const vk::DescriptorSet::DynamicOffsets & descriptorDynamicOffsets,const vk::Pipeline::PushConstantStorage & pushConstants,uint32_t baseGroupX,uint32_t baseGroupY,uint32_t baseGroupZ,uint32_t groupCountX,uint32_t groupCountY,uint32_t groupCountZ)190*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::run(
191*03ce13f7SAndroid Build Coastguard Worker     const vk::DescriptorSet::Array &descriptorSetObjects,
192*03ce13f7SAndroid Build Coastguard Worker     const vk::DescriptorSet::Bindings &descriptorSets,
193*03ce13f7SAndroid Build Coastguard Worker     const vk::DescriptorSet::DynamicOffsets &descriptorDynamicOffsets,
194*03ce13f7SAndroid Build Coastguard Worker     const vk::Pipeline::PushConstantStorage &pushConstants,
195*03ce13f7SAndroid Build Coastguard Worker     uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
196*03ce13f7SAndroid Build Coastguard Worker     uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ)
197*03ce13f7SAndroid Build Coastguard Worker {
198*03ce13f7SAndroid Build Coastguard Worker 	uint32_t workgroupSizeX = shader->getWorkgroupSizeX();
199*03ce13f7SAndroid Build Coastguard Worker 	uint32_t workgroupSizeY = shader->getWorkgroupSizeY();
200*03ce13f7SAndroid Build Coastguard Worker 	uint32_t workgroupSizeZ = shader->getWorkgroupSizeZ();
201*03ce13f7SAndroid Build Coastguard Worker 
202*03ce13f7SAndroid Build Coastguard Worker 	auto invocationsPerSubgroup = SIMD::Width;
203*03ce13f7SAndroid Build Coastguard Worker 	auto invocationsPerWorkgroup = workgroupSizeX * workgroupSizeY * workgroupSizeZ;
204*03ce13f7SAndroid Build Coastguard Worker 	auto subgroupsPerWorkgroup = (invocationsPerWorkgroup + invocationsPerSubgroup - 1) / invocationsPerSubgroup;
205*03ce13f7SAndroid Build Coastguard Worker 
206*03ce13f7SAndroid Build Coastguard Worker 	Data data;
207*03ce13f7SAndroid Build Coastguard Worker 	data.descriptorSets = descriptorSets;
208*03ce13f7SAndroid Build Coastguard Worker 	data.descriptorDynamicOffsets = descriptorDynamicOffsets;
209*03ce13f7SAndroid Build Coastguard Worker 	data.numWorkgroups[0] = groupCountX;
210*03ce13f7SAndroid Build Coastguard Worker 	data.numWorkgroups[1] = groupCountY;
211*03ce13f7SAndroid Build Coastguard Worker 	data.numWorkgroups[2] = groupCountZ;
212*03ce13f7SAndroid Build Coastguard Worker 	data.workgroupSize[0] = workgroupSizeX;
213*03ce13f7SAndroid Build Coastguard Worker 	data.workgroupSize[1] = workgroupSizeY;
214*03ce13f7SAndroid Build Coastguard Worker 	data.workgroupSize[2] = workgroupSizeZ;
215*03ce13f7SAndroid Build Coastguard Worker 	data.invocationsPerSubgroup = invocationsPerSubgroup;
216*03ce13f7SAndroid Build Coastguard Worker 	data.invocationsPerWorkgroup = invocationsPerWorkgroup;
217*03ce13f7SAndroid Build Coastguard Worker 	data.subgroupsPerWorkgroup = subgroupsPerWorkgroup;
218*03ce13f7SAndroid Build Coastguard Worker 	data.pushConstants = pushConstants;
219*03ce13f7SAndroid Build Coastguard Worker 
220*03ce13f7SAndroid Build Coastguard Worker 	marl::WaitGroup wg;
221*03ce13f7SAndroid Build Coastguard Worker 	constexpr uint32_t batchCount = 16;
222*03ce13f7SAndroid Build Coastguard Worker 
223*03ce13f7SAndroid Build Coastguard Worker 	auto groupCount = groupCountX * groupCountY * groupCountZ;
224*03ce13f7SAndroid Build Coastguard Worker 
225*03ce13f7SAndroid Build Coastguard Worker 	for(uint32_t batchID = 0; batchID < batchCount && batchID < groupCount; batchID++)
226*03ce13f7SAndroid Build Coastguard Worker 	{
227*03ce13f7SAndroid Build Coastguard Worker 		wg.add(1);
228*03ce13f7SAndroid Build Coastguard Worker 		marl::schedule([this, batchID, groupCount, groupCountX, groupCountY,
229*03ce13f7SAndroid Build Coastguard Worker 		                baseGroupZ, baseGroupY, baseGroupX, wg, subgroupsPerWorkgroup,
230*03ce13f7SAndroid Build Coastguard Worker 		                &data] {
231*03ce13f7SAndroid Build Coastguard Worker 			// Workaround for the fact that some compilers don't allow batchCount to be captured.
232*03ce13f7SAndroid Build Coastguard Worker 			constexpr uint32_t batchCount = 16;
233*03ce13f7SAndroid Build Coastguard Worker 			defer(wg.done());
234*03ce13f7SAndroid Build Coastguard Worker 			std::vector<uint8_t> workgroupMemory(shader->workgroupMemory.size());
235*03ce13f7SAndroid Build Coastguard Worker 
236*03ce13f7SAndroid Build Coastguard Worker 			for(uint32_t groupIndex = batchID; groupIndex < groupCount; groupIndex += batchCount)
237*03ce13f7SAndroid Build Coastguard Worker 			{
238*03ce13f7SAndroid Build Coastguard Worker 				auto modulo = groupIndex;
239*03ce13f7SAndroid Build Coastguard Worker 				auto groupOffsetZ = modulo / (groupCountX * groupCountY);
240*03ce13f7SAndroid Build Coastguard Worker 				modulo -= groupOffsetZ * (groupCountX * groupCountY);
241*03ce13f7SAndroid Build Coastguard Worker 				auto groupOffsetY = modulo / groupCountX;
242*03ce13f7SAndroid Build Coastguard Worker 				modulo -= groupOffsetY * groupCountX;
243*03ce13f7SAndroid Build Coastguard Worker 				auto groupOffsetX = modulo;
244*03ce13f7SAndroid Build Coastguard Worker 
245*03ce13f7SAndroid Build Coastguard Worker 				auto groupZ = baseGroupZ + groupOffsetZ;
246*03ce13f7SAndroid Build Coastguard Worker 				auto groupY = baseGroupY + groupOffsetY;
247*03ce13f7SAndroid Build Coastguard Worker 				auto groupX = baseGroupX + groupOffsetX;
248*03ce13f7SAndroid Build Coastguard Worker 				MARL_SCOPED_EVENT("groupX: %d, groupY: %d, groupZ: %d", groupX, groupY, groupZ);
249*03ce13f7SAndroid Build Coastguard Worker 
250*03ce13f7SAndroid Build Coastguard Worker 				using Coroutine = std::unique_ptr<rr::Stream<SpirvEmitter::YieldResult>>;
251*03ce13f7SAndroid Build Coastguard Worker 				std::queue<Coroutine> coroutines;
252*03ce13f7SAndroid Build Coastguard Worker 
253*03ce13f7SAndroid Build Coastguard Worker 				if(shader->getAnalysis().ContainsControlBarriers)
254*03ce13f7SAndroid Build Coastguard Worker 				{
255*03ce13f7SAndroid Build Coastguard Worker 					// Make a function call per subgroup so each subgroup
256*03ce13f7SAndroid Build Coastguard Worker 					// can yield, bringing all subgroups to the barrier
257*03ce13f7SAndroid Build Coastguard Worker 					// together.
258*03ce13f7SAndroid Build Coastguard Worker 					for(uint32_t subgroupIndex = 0; subgroupIndex < subgroupsPerWorkgroup; subgroupIndex++)
259*03ce13f7SAndroid Build Coastguard Worker 					{
260*03ce13f7SAndroid Build Coastguard Worker 						auto coroutine = (*this)(device, &data, groupX, groupY, groupZ, workgroupMemory.data(), subgroupIndex, 1);
261*03ce13f7SAndroid Build Coastguard Worker 						coroutines.push(std::move(coroutine));
262*03ce13f7SAndroid Build Coastguard Worker 					}
263*03ce13f7SAndroid Build Coastguard Worker 				}
264*03ce13f7SAndroid Build Coastguard Worker 				else
265*03ce13f7SAndroid Build Coastguard Worker 				{
266*03ce13f7SAndroid Build Coastguard Worker 					auto coroutine = (*this)(device, &data, groupX, groupY, groupZ, workgroupMemory.data(), 0, subgroupsPerWorkgroup);
267*03ce13f7SAndroid Build Coastguard Worker 					coroutines.push(std::move(coroutine));
268*03ce13f7SAndroid Build Coastguard Worker 				}
269*03ce13f7SAndroid Build Coastguard Worker 
270*03ce13f7SAndroid Build Coastguard Worker 				while(coroutines.size() > 0)
271*03ce13f7SAndroid Build Coastguard Worker 				{
272*03ce13f7SAndroid Build Coastguard Worker 					auto coroutine = std::move(coroutines.front());
273*03ce13f7SAndroid Build Coastguard Worker 					coroutines.pop();
274*03ce13f7SAndroid Build Coastguard Worker 
275*03ce13f7SAndroid Build Coastguard Worker 					SpirvEmitter::YieldResult result;
276*03ce13f7SAndroid Build Coastguard Worker 					if(coroutine->await(result))
277*03ce13f7SAndroid Build Coastguard Worker 					{
278*03ce13f7SAndroid Build Coastguard Worker 						// TODO: Consider result (when the enum is more than 1 entry).
279*03ce13f7SAndroid Build Coastguard Worker 						coroutines.push(std::move(coroutine));
280*03ce13f7SAndroid Build Coastguard Worker 					}
281*03ce13f7SAndroid Build Coastguard Worker 				}
282*03ce13f7SAndroid Build Coastguard Worker 			}
283*03ce13f7SAndroid Build Coastguard Worker 		});
284*03ce13f7SAndroid Build Coastguard Worker 	}
285*03ce13f7SAndroid Build Coastguard Worker 
286*03ce13f7SAndroid Build Coastguard Worker 	wg.wait();
287*03ce13f7SAndroid Build Coastguard Worker 
288*03ce13f7SAndroid Build Coastguard Worker 	if(shader->containsImageWrite())
289*03ce13f7SAndroid Build Coastguard Worker 	{
290*03ce13f7SAndroid Build Coastguard Worker 		vk::DescriptorSet::ContentsChanged(descriptorSetObjects, pipelineLayout, device);
291*03ce13f7SAndroid Build Coastguard Worker 	}
292*03ce13f7SAndroid Build Coastguard Worker }
293*03ce13f7SAndroid Build Coastguard Worker 
294*03ce13f7SAndroid Build Coastguard Worker }  // namespace sw
295