xref: /aosp_15_r20/external/skia/src/gpu/graphite/dawn/DawnComputePipeline.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1*c8dee2aaSAndroid Build Coastguard Worker /*
2*c8dee2aaSAndroid Build Coastguard Worker  * Copyright 2023 Google LLC
3*c8dee2aaSAndroid Build Coastguard Worker  *
4*c8dee2aaSAndroid Build Coastguard Worker  * Use of this source code is governed by a BSD-style license that can be
5*c8dee2aaSAndroid Build Coastguard Worker  * found in the LICENSE file.
6*c8dee2aaSAndroid Build Coastguard Worker  */
7*c8dee2aaSAndroid Build Coastguard Worker 
8*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnComputePipeline.h"
9*c8dee2aaSAndroid Build Coastguard Worker 
10*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/SkSLToBackend.h"
11*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/Caps.h"
12*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/ComputePipelineDesc.h"
13*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/ContextUtils.h"
14*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnAsyncWait.h"
15*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnErrorChecker.h"
16*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnGraphiteTypesPriv.h"
17*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnGraphiteUtilsPriv.h"
18*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnSharedContext.h"
19*c8dee2aaSAndroid Build Coastguard Worker #include "src/gpu/graphite/dawn/DawnUtilsPriv.h"
20*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/SkSLProgramSettings.h"
21*c8dee2aaSAndroid Build Coastguard Worker 
22*c8dee2aaSAndroid Build Coastguard Worker namespace skgpu::graphite {
23*c8dee2aaSAndroid Build Coastguard Worker namespace {
24*c8dee2aaSAndroid Build Coastguard Worker 
25*c8dee2aaSAndroid Build Coastguard Worker struct ShaderInfo {
26*c8dee2aaSAndroid Build Coastguard Worker     wgpu::ShaderModule fModule;
27*c8dee2aaSAndroid Build Coastguard Worker     std::string fEntryPoint;
28*c8dee2aaSAndroid Build Coastguard Worker 
isValidskgpu::graphite::__anon992c045a0111::ShaderInfo29*c8dee2aaSAndroid Build Coastguard Worker     bool isValid() const { return static_cast<bool>(fModule); }
30*c8dee2aaSAndroid Build Coastguard Worker };
31*c8dee2aaSAndroid Build Coastguard Worker 
compile_shader_module(const DawnSharedContext * sharedContext,const ComputePipelineDesc & pipelineDesc)32*c8dee2aaSAndroid Build Coastguard Worker static ShaderInfo compile_shader_module(const DawnSharedContext* sharedContext,
33*c8dee2aaSAndroid Build Coastguard Worker                                         const ComputePipelineDesc& pipelineDesc) {
34*c8dee2aaSAndroid Build Coastguard Worker     SkASSERT(sharedContext);
35*c8dee2aaSAndroid Build Coastguard Worker 
36*c8dee2aaSAndroid Build Coastguard Worker     ShaderInfo info;
37*c8dee2aaSAndroid Build Coastguard Worker 
38*c8dee2aaSAndroid Build Coastguard Worker     const Caps* caps = sharedContext->caps();
39*c8dee2aaSAndroid Build Coastguard Worker     const ComputeStep* step = pipelineDesc.computeStep();
40*c8dee2aaSAndroid Build Coastguard Worker     ShaderErrorHandler* errorHandler = caps->shaderErrorHandler();
41*c8dee2aaSAndroid Build Coastguard Worker 
42*c8dee2aaSAndroid Build Coastguard Worker     if (step->supportsNativeShader()) {
43*c8dee2aaSAndroid Build Coastguard Worker         auto nativeShader = step->nativeShaderSource(ComputeStep::NativeShaderFormat::kWGSL);
44*c8dee2aaSAndroid Build Coastguard Worker         if (!DawnCompileWGSLShaderModule(sharedContext,
45*c8dee2aaSAndroid Build Coastguard Worker                                          step->name(),
46*c8dee2aaSAndroid Build Coastguard Worker                                          std::string(nativeShader.fSource),
47*c8dee2aaSAndroid Build Coastguard Worker                                          &info.fModule,
48*c8dee2aaSAndroid Build Coastguard Worker                                          errorHandler)) {
49*c8dee2aaSAndroid Build Coastguard Worker             return {};
50*c8dee2aaSAndroid Build Coastguard Worker         }
51*c8dee2aaSAndroid Build Coastguard Worker         info.fEntryPoint = std::move(nativeShader.fEntryPoint);
52*c8dee2aaSAndroid Build Coastguard Worker     } else {
53*c8dee2aaSAndroid Build Coastguard Worker         std::string wgsl;
54*c8dee2aaSAndroid Build Coastguard Worker         SkSL::Program::Interface interface;
55*c8dee2aaSAndroid Build Coastguard Worker         SkSL::ProgramSettings settings;
56*c8dee2aaSAndroid Build Coastguard Worker 
57*c8dee2aaSAndroid Build Coastguard Worker         std::string sksl = BuildComputeSkSL(caps, step);
58*c8dee2aaSAndroid Build Coastguard Worker         if (skgpu::SkSLToWGSL(caps->shaderCaps(),
59*c8dee2aaSAndroid Build Coastguard Worker                               sksl,
60*c8dee2aaSAndroid Build Coastguard Worker                               SkSL::ProgramKind::kCompute,
61*c8dee2aaSAndroid Build Coastguard Worker                               settings,
62*c8dee2aaSAndroid Build Coastguard Worker                               &wgsl,
63*c8dee2aaSAndroid Build Coastguard Worker                               &interface,
64*c8dee2aaSAndroid Build Coastguard Worker                               errorHandler)) {
65*c8dee2aaSAndroid Build Coastguard Worker             if (!DawnCompileWGSLShaderModule(sharedContext, step->name(), wgsl,
66*c8dee2aaSAndroid Build Coastguard Worker                                              &info.fModule, errorHandler)) {
67*c8dee2aaSAndroid Build Coastguard Worker                 return {};
68*c8dee2aaSAndroid Build Coastguard Worker             }
69*c8dee2aaSAndroid Build Coastguard Worker             info.fEntryPoint = "main";
70*c8dee2aaSAndroid Build Coastguard Worker         }
71*c8dee2aaSAndroid Build Coastguard Worker     }
72*c8dee2aaSAndroid Build Coastguard Worker 
73*c8dee2aaSAndroid Build Coastguard Worker     return info;
74*c8dee2aaSAndroid Build Coastguard Worker }
75*c8dee2aaSAndroid Build Coastguard Worker 
76*c8dee2aaSAndroid Build Coastguard Worker }  // namespace
77*c8dee2aaSAndroid Build Coastguard Worker 
Make(const DawnSharedContext * sharedContext,const ComputePipelineDesc & pipelineDesc)78*c8dee2aaSAndroid Build Coastguard Worker sk_sp<DawnComputePipeline> DawnComputePipeline::Make(const DawnSharedContext* sharedContext,
79*c8dee2aaSAndroid Build Coastguard Worker                                                      const ComputePipelineDesc& pipelineDesc) {
80*c8dee2aaSAndroid Build Coastguard Worker     auto [shaderModule, entryPointName] = compile_shader_module(sharedContext, pipelineDesc);
81*c8dee2aaSAndroid Build Coastguard Worker     if (!shaderModule) {
82*c8dee2aaSAndroid Build Coastguard Worker         return nullptr;
83*c8dee2aaSAndroid Build Coastguard Worker     }
84*c8dee2aaSAndroid Build Coastguard Worker 
85*c8dee2aaSAndroid Build Coastguard Worker     const ComputeStep* step = pipelineDesc.computeStep();
86*c8dee2aaSAndroid Build Coastguard Worker 
87*c8dee2aaSAndroid Build Coastguard Worker     // ComputeStep resources are listed in the order that they must be declared in the shader. This
88*c8dee2aaSAndroid Build Coastguard Worker     // order is then used for the index assignment using an "indexed by order" policy that has
89*c8dee2aaSAndroid Build Coastguard Worker     // backend-specific semantics. The semantics on Dawn is to assign the index number in increasing
90*c8dee2aaSAndroid Build Coastguard Worker     // order.
91*c8dee2aaSAndroid Build Coastguard Worker     //
92*c8dee2aaSAndroid Build Coastguard Worker     // All resources get assigned to a single bind group at index 0.
93*c8dee2aaSAndroid Build Coastguard Worker     SkASSERT(!sharedContext->caps()->resourceBindingRequirements().fDistinctIndexRanges);
94*c8dee2aaSAndroid Build Coastguard Worker     std::vector<wgpu::BindGroupLayoutEntry> bindGroupLayoutEntries;
95*c8dee2aaSAndroid Build Coastguard Worker     auto resources = step->resources();
96*c8dee2aaSAndroid Build Coastguard Worker 
97*c8dee2aaSAndroid Build Coastguard Worker     // Sampled textures count as 2 resources (1 texture and 1 sampler). All other types count as 1.
98*c8dee2aaSAndroid Build Coastguard Worker     size_t resourceCount = 0;
99*c8dee2aaSAndroid Build Coastguard Worker     for (const ComputeStep::ResourceDesc& r : resources) {
100*c8dee2aaSAndroid Build Coastguard Worker         resourceCount++;
101*c8dee2aaSAndroid Build Coastguard Worker         if (r.fType == ComputeStep::ResourceType::kSampledTexture) {
102*c8dee2aaSAndroid Build Coastguard Worker             resourceCount++;
103*c8dee2aaSAndroid Build Coastguard Worker         }
104*c8dee2aaSAndroid Build Coastguard Worker     }
105*c8dee2aaSAndroid Build Coastguard Worker 
106*c8dee2aaSAndroid Build Coastguard Worker     bindGroupLayoutEntries.reserve(resourceCount);
107*c8dee2aaSAndroid Build Coastguard Worker     int declarationIndex = 0;
108*c8dee2aaSAndroid Build Coastguard Worker     for (const ComputeStep::ResourceDesc& r : resources) {
109*c8dee2aaSAndroid Build Coastguard Worker         bindGroupLayoutEntries.emplace_back();
110*c8dee2aaSAndroid Build Coastguard Worker         uint32_t bindingIndex = bindGroupLayoutEntries.size() - 1;
111*c8dee2aaSAndroid Build Coastguard Worker 
112*c8dee2aaSAndroid Build Coastguard Worker         wgpu::BindGroupLayoutEntry& entry = bindGroupLayoutEntries.back();
113*c8dee2aaSAndroid Build Coastguard Worker         entry.binding = bindingIndex;
114*c8dee2aaSAndroid Build Coastguard Worker         entry.visibility = wgpu::ShaderStage::Compute;
115*c8dee2aaSAndroid Build Coastguard Worker         switch (r.fType) {
116*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kUniformBuffer:
117*c8dee2aaSAndroid Build Coastguard Worker                 entry.buffer.type = wgpu::BufferBindingType::Uniform;
118*c8dee2aaSAndroid Build Coastguard Worker                 break;
119*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kStorageBuffer:
120*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kIndirectBuffer:
121*c8dee2aaSAndroid Build Coastguard Worker                 entry.buffer.type = wgpu::BufferBindingType::Storage;
122*c8dee2aaSAndroid Build Coastguard Worker                 break;
123*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kReadOnlyStorageBuffer:
124*c8dee2aaSAndroid Build Coastguard Worker                 entry.buffer.type = wgpu::BufferBindingType::ReadOnlyStorage;
125*c8dee2aaSAndroid Build Coastguard Worker                 break;
126*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kReadOnlyTexture:
127*c8dee2aaSAndroid Build Coastguard Worker                 entry.texture.sampleType = wgpu::TextureSampleType::Float;
128*c8dee2aaSAndroid Build Coastguard Worker                 entry.texture.viewDimension = wgpu::TextureViewDimension::e2D;
129*c8dee2aaSAndroid Build Coastguard Worker                 break;
130*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kWriteOnlyStorageTexture: {
131*c8dee2aaSAndroid Build Coastguard Worker                 entry.storageTexture.access = wgpu::StorageTextureAccess::WriteOnly;
132*c8dee2aaSAndroid Build Coastguard Worker                 entry.storageTexture.viewDimension = wgpu::TextureViewDimension::e2D;
133*c8dee2aaSAndroid Build Coastguard Worker 
134*c8dee2aaSAndroid Build Coastguard Worker                 auto [_, colorType] = step->calculateTextureParameters(declarationIndex, r);
135*c8dee2aaSAndroid Build Coastguard Worker                 auto textureInfo = sharedContext->caps()->getDefaultStorageTextureInfo(colorType);
136*c8dee2aaSAndroid Build Coastguard Worker                 entry.storageTexture.format = TextureInfos::GetDawnViewFormat(textureInfo);
137*c8dee2aaSAndroid Build Coastguard Worker                 break;
138*c8dee2aaSAndroid Build Coastguard Worker             }
139*c8dee2aaSAndroid Build Coastguard Worker             case ComputeStep::ResourceType::kSampledTexture: {
140*c8dee2aaSAndroid Build Coastguard Worker                 entry.sampler.type = wgpu::SamplerBindingType::Filtering;
141*c8dee2aaSAndroid Build Coastguard Worker 
142*c8dee2aaSAndroid Build Coastguard Worker                 // Add an additional entry for the texture.
143*c8dee2aaSAndroid Build Coastguard Worker                 bindGroupLayoutEntries.emplace_back();
144*c8dee2aaSAndroid Build Coastguard Worker                 wgpu::BindGroupLayoutEntry& texEntry = bindGroupLayoutEntries.back();
145*c8dee2aaSAndroid Build Coastguard Worker                 texEntry.binding = bindingIndex + 1;
146*c8dee2aaSAndroid Build Coastguard Worker                 texEntry.visibility = wgpu::ShaderStage::Compute;
147*c8dee2aaSAndroid Build Coastguard Worker                 texEntry.texture.sampleType = wgpu::TextureSampleType::Float;
148*c8dee2aaSAndroid Build Coastguard Worker                 texEntry.texture.viewDimension = wgpu::TextureViewDimension::e2D;
149*c8dee2aaSAndroid Build Coastguard Worker                 break;
150*c8dee2aaSAndroid Build Coastguard Worker             }
151*c8dee2aaSAndroid Build Coastguard Worker         }
152*c8dee2aaSAndroid Build Coastguard Worker         declarationIndex++;
153*c8dee2aaSAndroid Build Coastguard Worker     }
154*c8dee2aaSAndroid Build Coastguard Worker 
155*c8dee2aaSAndroid Build Coastguard Worker     const wgpu::Device& device = sharedContext->device();
156*c8dee2aaSAndroid Build Coastguard Worker 
157*c8dee2aaSAndroid Build Coastguard Worker     // All resources of a ComputeStep currently get assigned to a single bind group at index 0.
158*c8dee2aaSAndroid Build Coastguard Worker     wgpu::BindGroupLayoutDescriptor bindGroupLayoutDesc;
159*c8dee2aaSAndroid Build Coastguard Worker     bindGroupLayoutDesc.entryCount = bindGroupLayoutEntries.size();
160*c8dee2aaSAndroid Build Coastguard Worker     bindGroupLayoutDesc.entries = bindGroupLayoutEntries.data();
161*c8dee2aaSAndroid Build Coastguard Worker     wgpu::BindGroupLayout bindGroupLayout = device.CreateBindGroupLayout(&bindGroupLayoutDesc);
162*c8dee2aaSAndroid Build Coastguard Worker     if (!bindGroupLayout) {
163*c8dee2aaSAndroid Build Coastguard Worker         return nullptr;
164*c8dee2aaSAndroid Build Coastguard Worker     }
165*c8dee2aaSAndroid Build Coastguard Worker 
166*c8dee2aaSAndroid Build Coastguard Worker     wgpu::PipelineLayoutDescriptor pipelineLayoutDesc;
167*c8dee2aaSAndroid Build Coastguard Worker     if (sharedContext->caps()->setBackendLabels()) {
168*c8dee2aaSAndroid Build Coastguard Worker         pipelineLayoutDesc.label = step->name();
169*c8dee2aaSAndroid Build Coastguard Worker     }
170*c8dee2aaSAndroid Build Coastguard Worker     pipelineLayoutDesc.bindGroupLayoutCount = 1;
171*c8dee2aaSAndroid Build Coastguard Worker     pipelineLayoutDesc.bindGroupLayouts = &bindGroupLayout;
172*c8dee2aaSAndroid Build Coastguard Worker     wgpu::PipelineLayout layout = device.CreatePipelineLayout(&pipelineLayoutDesc);
173*c8dee2aaSAndroid Build Coastguard Worker     if (!layout) {
174*c8dee2aaSAndroid Build Coastguard Worker         return nullptr;
175*c8dee2aaSAndroid Build Coastguard Worker     }
176*c8dee2aaSAndroid Build Coastguard Worker 
177*c8dee2aaSAndroid Build Coastguard Worker     wgpu::ComputePipelineDescriptor descriptor;
178*c8dee2aaSAndroid Build Coastguard Worker     // Always set the label for pipelines, dawn may need it for tracing.
179*c8dee2aaSAndroid Build Coastguard Worker     descriptor.label = step->name();
180*c8dee2aaSAndroid Build Coastguard Worker     descriptor.compute.module = std::move(shaderModule);
181*c8dee2aaSAndroid Build Coastguard Worker     descriptor.compute.entryPoint = entryPointName.c_str();
182*c8dee2aaSAndroid Build Coastguard Worker     descriptor.layout = std::move(layout);
183*c8dee2aaSAndroid Build Coastguard Worker 
184*c8dee2aaSAndroid Build Coastguard Worker     std::optional<DawnErrorChecker> errorChecker;
185*c8dee2aaSAndroid Build Coastguard Worker     if (sharedContext->dawnCaps()->allowScopedErrorChecks()) {
186*c8dee2aaSAndroid Build Coastguard Worker         errorChecker.emplace(sharedContext);
187*c8dee2aaSAndroid Build Coastguard Worker     }
188*c8dee2aaSAndroid Build Coastguard Worker     wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&descriptor);
189*c8dee2aaSAndroid Build Coastguard Worker     SkASSERT(pipeline);
190*c8dee2aaSAndroid Build Coastguard Worker     if (errorChecker.has_value() && errorChecker->popErrorScopes() != DawnErrorType::kNoError) {
191*c8dee2aaSAndroid Build Coastguard Worker         return nullptr;
192*c8dee2aaSAndroid Build Coastguard Worker     }
193*c8dee2aaSAndroid Build Coastguard Worker 
194*c8dee2aaSAndroid Build Coastguard Worker     return sk_sp<DawnComputePipeline>(new DawnComputePipeline(
195*c8dee2aaSAndroid Build Coastguard Worker             sharedContext, std::move(pipeline), std::move(bindGroupLayout)));
196*c8dee2aaSAndroid Build Coastguard Worker }
197*c8dee2aaSAndroid Build Coastguard Worker 
DawnComputePipeline(const SharedContext * sharedContext,wgpu::ComputePipeline pso,wgpu::BindGroupLayout groupLayout)198*c8dee2aaSAndroid Build Coastguard Worker DawnComputePipeline::DawnComputePipeline(const SharedContext* sharedContext,
199*c8dee2aaSAndroid Build Coastguard Worker                                          wgpu::ComputePipeline pso,
200*c8dee2aaSAndroid Build Coastguard Worker                                          wgpu::BindGroupLayout groupLayout)
201*c8dee2aaSAndroid Build Coastguard Worker         : ComputePipeline(sharedContext)
202*c8dee2aaSAndroid Build Coastguard Worker         , fPipeline(std::move(pso))
203*c8dee2aaSAndroid Build Coastguard Worker         , fGroupLayout(std::move(groupLayout)) {}
204*c8dee2aaSAndroid Build Coastguard Worker 
freeGpuData()205*c8dee2aaSAndroid Build Coastguard Worker void DawnComputePipeline::freeGpuData() { fPipeline = nullptr; }
206*c8dee2aaSAndroid Build Coastguard Worker 
207*c8dee2aaSAndroid Build Coastguard Worker }  // namespace skgpu::graphite
208