1*03ce13f7SAndroid Build Coastguard Worker // Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2*03ce13f7SAndroid Build Coastguard Worker //
3*03ce13f7SAndroid Build Coastguard Worker // Licensed under the Apache License, Version 2.0 (the "License");
4*03ce13f7SAndroid Build Coastguard Worker // you may not use this file except in compliance with the License.
5*03ce13f7SAndroid Build Coastguard Worker // You may obtain a copy of the License at
6*03ce13f7SAndroid Build Coastguard Worker //
7*03ce13f7SAndroid Build Coastguard Worker // http://www.apache.org/licenses/LICENSE-2.0
8*03ce13f7SAndroid Build Coastguard Worker //
9*03ce13f7SAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
10*03ce13f7SAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS,
11*03ce13f7SAndroid Build Coastguard Worker // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*03ce13f7SAndroid Build Coastguard Worker // See the License for the specific language governing permissions and
13*03ce13f7SAndroid Build Coastguard Worker // limitations under the License.
14*03ce13f7SAndroid Build Coastguard Worker
#include "ComputeProgram.hpp"

#include "Constants.hpp"
#include "System/Debug.hpp"
#include "Vulkan/VkDevice.hpp"
#include "Vulkan/VkPipelineLayout.hpp"

#include "marl/defer.h"
#include "marl/trace.h"
#include "marl/waitgroup.h"

#include <cstdint>
#include <queue>
#include <utility>
27*03ce13f7SAndroid Build Coastguard Worker
28*03ce13f7SAndroid Build Coastguard Worker namespace sw {
29*03ce13f7SAndroid Build Coastguard Worker
ComputeProgram(vk::Device * device,std::shared_ptr<SpirvShader> shader,const vk::PipelineLayout * pipelineLayout,const vk::DescriptorSet::Bindings & descriptorSets)30*03ce13f7SAndroid Build Coastguard Worker ComputeProgram::ComputeProgram(vk::Device *device, std::shared_ptr<SpirvShader> shader, const vk::PipelineLayout *pipelineLayout, const vk::DescriptorSet::Bindings &descriptorSets)
31*03ce13f7SAndroid Build Coastguard Worker : device(device)
32*03ce13f7SAndroid Build Coastguard Worker , shader(shader)
33*03ce13f7SAndroid Build Coastguard Worker , pipelineLayout(pipelineLayout)
34*03ce13f7SAndroid Build Coastguard Worker , descriptorSets(descriptorSets)
35*03ce13f7SAndroid Build Coastguard Worker {
36*03ce13f7SAndroid Build Coastguard Worker }
37*03ce13f7SAndroid Build Coastguard Worker
~ComputeProgram()38*03ce13f7SAndroid Build Coastguard Worker ComputeProgram::~ComputeProgram()
39*03ce13f7SAndroid Build Coastguard Worker {
40*03ce13f7SAndroid Build Coastguard Worker }
41*03ce13f7SAndroid Build Coastguard Worker
generate()42*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::generate()
43*03ce13f7SAndroid Build Coastguard Worker {
44*03ce13f7SAndroid Build Coastguard Worker MARL_SCOPED_EVENT("ComputeProgram::generate");
45*03ce13f7SAndroid Build Coastguard Worker
46*03ce13f7SAndroid Build Coastguard Worker SpirvRoutine routine(pipelineLayout);
47*03ce13f7SAndroid Build Coastguard Worker shader->emitProlog(&routine);
48*03ce13f7SAndroid Build Coastguard Worker emit(&routine);
49*03ce13f7SAndroid Build Coastguard Worker shader->emitEpilog(&routine);
50*03ce13f7SAndroid Build Coastguard Worker }
51*03ce13f7SAndroid Build Coastguard Worker
setWorkgroupBuiltins(Pointer<Byte> data,SpirvRoutine * routine,Int workgroupID[3])52*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::setWorkgroupBuiltins(Pointer<Byte> data, SpirvRoutine *routine, Int workgroupID[3])
53*03ce13f7SAndroid Build Coastguard Worker {
54*03ce13f7SAndroid Build Coastguard Worker // TODO(b/146486064): Consider only assigning these to the SpirvRoutine iff they are ever going to be read.
55*03ce13f7SAndroid Build Coastguard Worker routine->numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
56*03ce13f7SAndroid Build Coastguard Worker routine->workgroupID = Insert(Insert(Insert(Int4(0), workgroupID[0], 0), workgroupID[1], 1), workgroupID[2], 2);
57*03ce13f7SAndroid Build Coastguard Worker routine->workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
58*03ce13f7SAndroid Build Coastguard Worker routine->subgroupsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, subgroupsPerWorkgroup));
59*03ce13f7SAndroid Build Coastguard Worker routine->invocationsPerSubgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerSubgroup));
60*03ce13f7SAndroid Build Coastguard Worker
61*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInNumWorkgroups, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
62*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.x));
63*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.y));
64*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(routine->numWorkgroups.z));
65*03ce13f7SAndroid Build Coastguard Worker });
66*03ce13f7SAndroid Build Coastguard Worker
67*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInWorkgroupId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
68*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(workgroupID[0]));
69*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(workgroupID[1]));
70*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(workgroupID[2]));
71*03ce13f7SAndroid Build Coastguard Worker });
72*03ce13f7SAndroid Build Coastguard Worker
73*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInWorkgroupSize, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
74*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 0] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.x));
75*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 1] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.y));
76*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + 2] = As<SIMD::Float>(SIMD::Int(routine->workgroupSize.z));
77*03ce13f7SAndroid Build Coastguard Worker });
78*03ce13f7SAndroid Build Coastguard Worker
79*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInNumSubgroups, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
80*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(routine->subgroupsPerWorkgroup));
81*03ce13f7SAndroid Build Coastguard Worker });
82*03ce13f7SAndroid Build Coastguard Worker
83*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInSubgroupSize, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
84*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(routine->invocationsPerSubgroup));
85*03ce13f7SAndroid Build Coastguard Worker });
86*03ce13f7SAndroid Build Coastguard Worker
87*03ce13f7SAndroid Build Coastguard Worker routine->setImmutableInputBuiltins(shader.get());
88*03ce13f7SAndroid Build Coastguard Worker }
89*03ce13f7SAndroid Build Coastguard Worker
setSubgroupBuiltins(Pointer<Byte> data,SpirvRoutine * routine,Int workgroupID[3],SIMD::Int localInvocationIndex,Int subgroupIndex)90*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::setSubgroupBuiltins(Pointer<Byte> data, SpirvRoutine *routine, Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex)
91*03ce13f7SAndroid Build Coastguard Worker {
92*03ce13f7SAndroid Build Coastguard Worker Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
93*03ce13f7SAndroid Build Coastguard Worker Int4 workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
94*03ce13f7SAndroid Build Coastguard Worker
95*03ce13f7SAndroid Build Coastguard Worker Int workgroupSizeX = workgroupSize.x;
96*03ce13f7SAndroid Build Coastguard Worker Int workgroupSizeY = workgroupSize.y;
97*03ce13f7SAndroid Build Coastguard Worker
98*03ce13f7SAndroid Build Coastguard Worker SIMD::Int localInvocationID[3];
99*03ce13f7SAndroid Build Coastguard Worker {
100*03ce13f7SAndroid Build Coastguard Worker SIMD::Int idx = localInvocationIndex;
101*03ce13f7SAndroid Build Coastguard Worker localInvocationID[2] = idx / SIMD::Int(workgroupSizeX * workgroupSizeY);
102*03ce13f7SAndroid Build Coastguard Worker idx -= localInvocationID[2] * SIMD::Int(workgroupSizeX * workgroupSizeY); // modulo
103*03ce13f7SAndroid Build Coastguard Worker localInvocationID[1] = idx / SIMD::Int(workgroupSizeX);
104*03ce13f7SAndroid Build Coastguard Worker idx -= localInvocationID[1] * SIMD::Int(workgroupSizeX); // modulo
105*03ce13f7SAndroid Build Coastguard Worker localInvocationID[0] = idx;
106*03ce13f7SAndroid Build Coastguard Worker }
107*03ce13f7SAndroid Build Coastguard Worker
108*03ce13f7SAndroid Build Coastguard Worker Int4 wgID = Insert(Insert(Insert(Int4(0), workgroupID[0], 0), workgroupID[1], 1), workgroupID[2], 2);
109*03ce13f7SAndroid Build Coastguard Worker auto localBase = workgroupSize * wgID;
110*03ce13f7SAndroid Build Coastguard Worker SIMD::Int globalInvocationID[3];
111*03ce13f7SAndroid Build Coastguard Worker globalInvocationID[0] = SIMD::Int(Extract(localBase, 0)) + localInvocationID[0];
112*03ce13f7SAndroid Build Coastguard Worker globalInvocationID[1] = SIMD::Int(Extract(localBase, 1)) + localInvocationID[1];
113*03ce13f7SAndroid Build Coastguard Worker globalInvocationID[2] = SIMD::Int(Extract(localBase, 2)) + localInvocationID[2];
114*03ce13f7SAndroid Build Coastguard Worker
115*03ce13f7SAndroid Build Coastguard Worker routine->localInvocationIndex = localInvocationIndex;
116*03ce13f7SAndroid Build Coastguard Worker routine->subgroupIndex = subgroupIndex;
117*03ce13f7SAndroid Build Coastguard Worker routine->localInvocationID[0] = localInvocationID[0];
118*03ce13f7SAndroid Build Coastguard Worker routine->localInvocationID[1] = localInvocationID[1];
119*03ce13f7SAndroid Build Coastguard Worker routine->localInvocationID[2] = localInvocationID[2];
120*03ce13f7SAndroid Build Coastguard Worker routine->globalInvocationID[0] = globalInvocationID[0];
121*03ce13f7SAndroid Build Coastguard Worker routine->globalInvocationID[1] = globalInvocationID[1];
122*03ce13f7SAndroid Build Coastguard Worker routine->globalInvocationID[2] = globalInvocationID[2];
123*03ce13f7SAndroid Build Coastguard Worker
124*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInLocalInvocationIndex, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
125*03ce13f7SAndroid Build Coastguard Worker ASSERT(builtin.SizeInComponents == 1);
126*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent] = As<SIMD::Float>(localInvocationIndex);
127*03ce13f7SAndroid Build Coastguard Worker });
128*03ce13f7SAndroid Build Coastguard Worker
129*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInSubgroupId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
130*03ce13f7SAndroid Build Coastguard Worker ASSERT(builtin.SizeInComponents == 1);
131*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupIndex));
132*03ce13f7SAndroid Build Coastguard Worker });
133*03ce13f7SAndroid Build Coastguard Worker
134*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInLocalInvocationId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
135*03ce13f7SAndroid Build Coastguard Worker for(uint32_t component = 0; component < builtin.SizeInComponents; component++)
136*03ce13f7SAndroid Build Coastguard Worker {
137*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + component] =
138*03ce13f7SAndroid Build Coastguard Worker As<SIMD::Float>(localInvocationID[component]);
139*03ce13f7SAndroid Build Coastguard Worker }
140*03ce13f7SAndroid Build Coastguard Worker });
141*03ce13f7SAndroid Build Coastguard Worker
142*03ce13f7SAndroid Build Coastguard Worker routine->setInputBuiltin(shader.get(), spv::BuiltInGlobalInvocationId, [&](const Spirv::BuiltinMapping &builtin, Array<SIMD::Float> &value) {
143*03ce13f7SAndroid Build Coastguard Worker for(uint32_t component = 0; component < builtin.SizeInComponents; component++)
144*03ce13f7SAndroid Build Coastguard Worker {
145*03ce13f7SAndroid Build Coastguard Worker value[builtin.FirstComponent + component] =
146*03ce13f7SAndroid Build Coastguard Worker As<SIMD::Float>(globalInvocationID[component]);
147*03ce13f7SAndroid Build Coastguard Worker }
148*03ce13f7SAndroid Build Coastguard Worker });
149*03ce13f7SAndroid Build Coastguard Worker }
150*03ce13f7SAndroid Build Coastguard Worker
emit(SpirvRoutine * routine)151*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::emit(SpirvRoutine *routine)
152*03ce13f7SAndroid Build Coastguard Worker {
153*03ce13f7SAndroid Build Coastguard Worker Pointer<Byte> device = Arg<0>();
154*03ce13f7SAndroid Build Coastguard Worker Pointer<Byte> data = Arg<1>();
155*03ce13f7SAndroid Build Coastguard Worker Int workgroupX = Arg<2>();
156*03ce13f7SAndroid Build Coastguard Worker Int workgroupY = Arg<3>();
157*03ce13f7SAndroid Build Coastguard Worker Int workgroupZ = Arg<4>();
158*03ce13f7SAndroid Build Coastguard Worker Pointer<Byte> workgroupMemory = Arg<5>();
159*03ce13f7SAndroid Build Coastguard Worker Int firstSubgroup = Arg<6>();
160*03ce13f7SAndroid Build Coastguard Worker Int subgroupCount = Arg<7>();
161*03ce13f7SAndroid Build Coastguard Worker
162*03ce13f7SAndroid Build Coastguard Worker routine->device = device;
163*03ce13f7SAndroid Build Coastguard Worker routine->descriptorSets = data + OFFSET(Data, descriptorSets);
164*03ce13f7SAndroid Build Coastguard Worker routine->descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
165*03ce13f7SAndroid Build Coastguard Worker routine->pushConstants = data + OFFSET(Data, pushConstants);
166*03ce13f7SAndroid Build Coastguard Worker routine->constants = device + OFFSET(vk::Device, constants);
167*03ce13f7SAndroid Build Coastguard Worker routine->workgroupMemory = workgroupMemory;
168*03ce13f7SAndroid Build Coastguard Worker
169*03ce13f7SAndroid Build Coastguard Worker Int invocationsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerWorkgroup));
170*03ce13f7SAndroid Build Coastguard Worker
171*03ce13f7SAndroid Build Coastguard Worker Int workgroupID[3] = { workgroupX, workgroupY, workgroupZ };
172*03ce13f7SAndroid Build Coastguard Worker setWorkgroupBuiltins(data, routine, workgroupID);
173*03ce13f7SAndroid Build Coastguard Worker
174*03ce13f7SAndroid Build Coastguard Worker For(Int i = 0, i < subgroupCount, i++)
175*03ce13f7SAndroid Build Coastguard Worker {
176*03ce13f7SAndroid Build Coastguard Worker auto subgroupIndex = firstSubgroup + i;
177*03ce13f7SAndroid Build Coastguard Worker
178*03ce13f7SAndroid Build Coastguard Worker // TODO: Replace SIMD::Int(0, 1, 2, 3) with SIMD-width equivalent
179*03ce13f7SAndroid Build Coastguard Worker auto localInvocationIndex = SIMD::Int(subgroupIndex * SIMD::Width) + SIMD::Int(0, 1, 2, 3);
180*03ce13f7SAndroid Build Coastguard Worker
181*03ce13f7SAndroid Build Coastguard Worker // Disable lanes where (invocationIDs >= invocationsPerWorkgroup)
182*03ce13f7SAndroid Build Coastguard Worker auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(invocationsPerWorkgroup));
183*03ce13f7SAndroid Build Coastguard Worker
184*03ce13f7SAndroid Build Coastguard Worker setSubgroupBuiltins(data, routine, workgroupID, localInvocationIndex, subgroupIndex);
185*03ce13f7SAndroid Build Coastguard Worker
186*03ce13f7SAndroid Build Coastguard Worker shader->emit(routine, activeLaneMask, activeLaneMask, descriptorSets);
187*03ce13f7SAndroid Build Coastguard Worker }
188*03ce13f7SAndroid Build Coastguard Worker }
189*03ce13f7SAndroid Build Coastguard Worker
run(const vk::DescriptorSet::Array & descriptorSetObjects,const vk::DescriptorSet::Bindings & descriptorSets,const vk::DescriptorSet::DynamicOffsets & descriptorDynamicOffsets,const vk::Pipeline::PushConstantStorage & pushConstants,uint32_t baseGroupX,uint32_t baseGroupY,uint32_t baseGroupZ,uint32_t groupCountX,uint32_t groupCountY,uint32_t groupCountZ)190*03ce13f7SAndroid Build Coastguard Worker void ComputeProgram::run(
191*03ce13f7SAndroid Build Coastguard Worker const vk::DescriptorSet::Array &descriptorSetObjects,
192*03ce13f7SAndroid Build Coastguard Worker const vk::DescriptorSet::Bindings &descriptorSets,
193*03ce13f7SAndroid Build Coastguard Worker const vk::DescriptorSet::DynamicOffsets &descriptorDynamicOffsets,
194*03ce13f7SAndroid Build Coastguard Worker const vk::Pipeline::PushConstantStorage &pushConstants,
195*03ce13f7SAndroid Build Coastguard Worker uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
196*03ce13f7SAndroid Build Coastguard Worker uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ)
197*03ce13f7SAndroid Build Coastguard Worker {
198*03ce13f7SAndroid Build Coastguard Worker uint32_t workgroupSizeX = shader->getWorkgroupSizeX();
199*03ce13f7SAndroid Build Coastguard Worker uint32_t workgroupSizeY = shader->getWorkgroupSizeY();
200*03ce13f7SAndroid Build Coastguard Worker uint32_t workgroupSizeZ = shader->getWorkgroupSizeZ();
201*03ce13f7SAndroid Build Coastguard Worker
202*03ce13f7SAndroid Build Coastguard Worker auto invocationsPerSubgroup = SIMD::Width;
203*03ce13f7SAndroid Build Coastguard Worker auto invocationsPerWorkgroup = workgroupSizeX * workgroupSizeY * workgroupSizeZ;
204*03ce13f7SAndroid Build Coastguard Worker auto subgroupsPerWorkgroup = (invocationsPerWorkgroup + invocationsPerSubgroup - 1) / invocationsPerSubgroup;
205*03ce13f7SAndroid Build Coastguard Worker
206*03ce13f7SAndroid Build Coastguard Worker Data data;
207*03ce13f7SAndroid Build Coastguard Worker data.descriptorSets = descriptorSets;
208*03ce13f7SAndroid Build Coastguard Worker data.descriptorDynamicOffsets = descriptorDynamicOffsets;
209*03ce13f7SAndroid Build Coastguard Worker data.numWorkgroups[0] = groupCountX;
210*03ce13f7SAndroid Build Coastguard Worker data.numWorkgroups[1] = groupCountY;
211*03ce13f7SAndroid Build Coastguard Worker data.numWorkgroups[2] = groupCountZ;
212*03ce13f7SAndroid Build Coastguard Worker data.workgroupSize[0] = workgroupSizeX;
213*03ce13f7SAndroid Build Coastguard Worker data.workgroupSize[1] = workgroupSizeY;
214*03ce13f7SAndroid Build Coastguard Worker data.workgroupSize[2] = workgroupSizeZ;
215*03ce13f7SAndroid Build Coastguard Worker data.invocationsPerSubgroup = invocationsPerSubgroup;
216*03ce13f7SAndroid Build Coastguard Worker data.invocationsPerWorkgroup = invocationsPerWorkgroup;
217*03ce13f7SAndroid Build Coastguard Worker data.subgroupsPerWorkgroup = subgroupsPerWorkgroup;
218*03ce13f7SAndroid Build Coastguard Worker data.pushConstants = pushConstants;
219*03ce13f7SAndroid Build Coastguard Worker
220*03ce13f7SAndroid Build Coastguard Worker marl::WaitGroup wg;
221*03ce13f7SAndroid Build Coastguard Worker constexpr uint32_t batchCount = 16;
222*03ce13f7SAndroid Build Coastguard Worker
223*03ce13f7SAndroid Build Coastguard Worker auto groupCount = groupCountX * groupCountY * groupCountZ;
224*03ce13f7SAndroid Build Coastguard Worker
225*03ce13f7SAndroid Build Coastguard Worker for(uint32_t batchID = 0; batchID < batchCount && batchID < groupCount; batchID++)
226*03ce13f7SAndroid Build Coastguard Worker {
227*03ce13f7SAndroid Build Coastguard Worker wg.add(1);
228*03ce13f7SAndroid Build Coastguard Worker marl::schedule([this, batchID, groupCount, groupCountX, groupCountY,
229*03ce13f7SAndroid Build Coastguard Worker baseGroupZ, baseGroupY, baseGroupX, wg, subgroupsPerWorkgroup,
230*03ce13f7SAndroid Build Coastguard Worker &data] {
231*03ce13f7SAndroid Build Coastguard Worker // Workaround for the fact that some compilers don't allow batchCount to be captured.
232*03ce13f7SAndroid Build Coastguard Worker constexpr uint32_t batchCount = 16;
233*03ce13f7SAndroid Build Coastguard Worker defer(wg.done());
234*03ce13f7SAndroid Build Coastguard Worker std::vector<uint8_t> workgroupMemory(shader->workgroupMemory.size());
235*03ce13f7SAndroid Build Coastguard Worker
236*03ce13f7SAndroid Build Coastguard Worker for(uint32_t groupIndex = batchID; groupIndex < groupCount; groupIndex += batchCount)
237*03ce13f7SAndroid Build Coastguard Worker {
238*03ce13f7SAndroid Build Coastguard Worker auto modulo = groupIndex;
239*03ce13f7SAndroid Build Coastguard Worker auto groupOffsetZ = modulo / (groupCountX * groupCountY);
240*03ce13f7SAndroid Build Coastguard Worker modulo -= groupOffsetZ * (groupCountX * groupCountY);
241*03ce13f7SAndroid Build Coastguard Worker auto groupOffsetY = modulo / groupCountX;
242*03ce13f7SAndroid Build Coastguard Worker modulo -= groupOffsetY * groupCountX;
243*03ce13f7SAndroid Build Coastguard Worker auto groupOffsetX = modulo;
244*03ce13f7SAndroid Build Coastguard Worker
245*03ce13f7SAndroid Build Coastguard Worker auto groupZ = baseGroupZ + groupOffsetZ;
246*03ce13f7SAndroid Build Coastguard Worker auto groupY = baseGroupY + groupOffsetY;
247*03ce13f7SAndroid Build Coastguard Worker auto groupX = baseGroupX + groupOffsetX;
248*03ce13f7SAndroid Build Coastguard Worker MARL_SCOPED_EVENT("groupX: %d, groupY: %d, groupZ: %d", groupX, groupY, groupZ);
249*03ce13f7SAndroid Build Coastguard Worker
250*03ce13f7SAndroid Build Coastguard Worker using Coroutine = std::unique_ptr<rr::Stream<SpirvEmitter::YieldResult>>;
251*03ce13f7SAndroid Build Coastguard Worker std::queue<Coroutine> coroutines;
252*03ce13f7SAndroid Build Coastguard Worker
253*03ce13f7SAndroid Build Coastguard Worker if(shader->getAnalysis().ContainsControlBarriers)
254*03ce13f7SAndroid Build Coastguard Worker {
255*03ce13f7SAndroid Build Coastguard Worker // Make a function call per subgroup so each subgroup
256*03ce13f7SAndroid Build Coastguard Worker // can yield, bringing all subgroups to the barrier
257*03ce13f7SAndroid Build Coastguard Worker // together.
258*03ce13f7SAndroid Build Coastguard Worker for(uint32_t subgroupIndex = 0; subgroupIndex < subgroupsPerWorkgroup; subgroupIndex++)
259*03ce13f7SAndroid Build Coastguard Worker {
260*03ce13f7SAndroid Build Coastguard Worker auto coroutine = (*this)(device, &data, groupX, groupY, groupZ, workgroupMemory.data(), subgroupIndex, 1);
261*03ce13f7SAndroid Build Coastguard Worker coroutines.push(std::move(coroutine));
262*03ce13f7SAndroid Build Coastguard Worker }
263*03ce13f7SAndroid Build Coastguard Worker }
264*03ce13f7SAndroid Build Coastguard Worker else
265*03ce13f7SAndroid Build Coastguard Worker {
266*03ce13f7SAndroid Build Coastguard Worker auto coroutine = (*this)(device, &data, groupX, groupY, groupZ, workgroupMemory.data(), 0, subgroupsPerWorkgroup);
267*03ce13f7SAndroid Build Coastguard Worker coroutines.push(std::move(coroutine));
268*03ce13f7SAndroid Build Coastguard Worker }
269*03ce13f7SAndroid Build Coastguard Worker
270*03ce13f7SAndroid Build Coastguard Worker while(coroutines.size() > 0)
271*03ce13f7SAndroid Build Coastguard Worker {
272*03ce13f7SAndroid Build Coastguard Worker auto coroutine = std::move(coroutines.front());
273*03ce13f7SAndroid Build Coastguard Worker coroutines.pop();
274*03ce13f7SAndroid Build Coastguard Worker
275*03ce13f7SAndroid Build Coastguard Worker SpirvEmitter::YieldResult result;
276*03ce13f7SAndroid Build Coastguard Worker if(coroutine->await(result))
277*03ce13f7SAndroid Build Coastguard Worker {
278*03ce13f7SAndroid Build Coastguard Worker // TODO: Consider result (when the enum is more than 1 entry).
279*03ce13f7SAndroid Build Coastguard Worker coroutines.push(std::move(coroutine));
280*03ce13f7SAndroid Build Coastguard Worker }
281*03ce13f7SAndroid Build Coastguard Worker }
282*03ce13f7SAndroid Build Coastguard Worker }
283*03ce13f7SAndroid Build Coastguard Worker });
284*03ce13f7SAndroid Build Coastguard Worker }
285*03ce13f7SAndroid Build Coastguard Worker
286*03ce13f7SAndroid Build Coastguard Worker wg.wait();
287*03ce13f7SAndroid Build Coastguard Worker
288*03ce13f7SAndroid Build Coastguard Worker if(shader->containsImageWrite())
289*03ce13f7SAndroid Build Coastguard Worker {
290*03ce13f7SAndroid Build Coastguard Worker vk::DescriptorSet::ContentsChanged(descriptorSetObjects, pipelineLayout, device);
291*03ce13f7SAndroid Build Coastguard Worker }
292*03ce13f7SAndroid Build Coastguard Worker }
293*03ce13f7SAndroid Build Coastguard Worker
294*03ce13f7SAndroid Build Coastguard Worker } // namespace sw
295