xref: /aosp_15_r20/external/skia/src/gpu/graphite/mtl/MtlComputePipeline.mm (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1/*
2 * Copyright 2022 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/gpu/graphite/mtl/MtlComputePipeline.h"
9
10#include "include/gpu/ShaderErrorHandler.h"
11#include "src/gpu/graphite/ComputePipelineDesc.h"
12#include "src/gpu/graphite/ContextUtils.h"
13#include "src/gpu/graphite/Log.h"
14#include "src/gpu/graphite/ResourceProvider.h"
15#include "src/gpu/graphite/mtl/MtlGraphiteUtilsPriv.h"
16#include "src/gpu/graphite/mtl/MtlSharedContext.h"
17#include "src/gpu/mtl/MtlUtilsPriv.h"
18#include "src/sksl/SkSLCompiler.h"
19#include "src/sksl/SkSLProgramKind.h"
20#include "src/sksl/SkSLProgramSettings.h"
21#include "src/sksl/ir/SkSLProgram.h"
22
23namespace skgpu::graphite {
24
// static
//
// Builds a Metal compute pipeline state object for `pipelineDesc`.
//
// The MSL source and entry point come from one of two places:
//   * If the ComputeStep supports a native shader, its MSL source and entry point are used as-is.
//   * Otherwise the step's SkSL is generated with BuildComputeSkSL(), translated to MSL with
//     SkSLToMSL(), and the entry point is "computeMain".
// The MSL is compiled into an MTLLibrary once, and the named function becomes the pipeline's
// compute function.
//
// Returns nullptr if SkSL->MSL translation, MSL compilation, entry-point lookup, or pipeline-state
// creation fails. The caps' ShaderErrorHandler is passed to SkSLToMSL() and
// MtlCompileShaderLibrary(); the later two failures are logged with SKGPU_LOG_E.
sk_sp<MtlComputePipeline> MtlComputePipeline::Make(const MtlSharedContext* sharedContext,
                                                   const ComputePipelineDesc& pipelineDesc) {
    const auto* computeStep = pipelineDesc.computeStep();
    ShaderErrorHandler* errorHandler = sharedContext->caps()->shaderErrorHandler();

    // Step 1: obtain the MSL source and the name of its entry-point function.
    std::string msl;
    std::string entryPointName;
    if (computeStep->supportsNativeShader()) {
        auto nativeShader =
                computeStep->nativeShaderSource(ComputeStep::NativeShaderFormat::kMSL);
        msl = std::move(nativeShader.fSource);
        entryPointName = std::move(nativeShader.fEntryPoint);
    } else {
        SkSL::Program::Interface interface;
        SkSL::ProgramSettings settings;
        std::string sksl = BuildComputeSkSL(sharedContext->caps(), computeStep);
        if (!SkSLToMSL(sharedContext->caps()->shaderCaps(),
                       sksl,
                       SkSL::ProgramKind::kCompute,
                       settings,
                       &msl,
                       &interface,
                       errorHandler)) {
            return nullptr;
        }
        entryPointName = "computeMain";
    }

    // Step 2: compile the MSL (shared by both paths above).
    sk_cfp<id<MTLLibrary>> library = MtlCompileShaderLibrary(sharedContext,
                                                             computeStep->name(),
                                                             msl,
                                                             errorHandler);
    if (library == nil) {
        return nullptr;
    }

    // Step 3: look up the entry point. `newFunctionWithName:` returns an owned (+1) reference, so
    // it is held in an sk_cfp, matching how the descriptor and pipeline state below are owned.
    // The descriptor's `computeFunction` property takes its own reference.
    NSString* entryPoint = [NSString stringWithUTF8String:entryPointName.c_str()];
    sk_cfp<id<MTLFunction>> computeFunction([library.get() newFunctionWithName:entryPoint]);
    if (!computeFunction) {
        SKGPU_LOG_E("Compute pipeline creation failure: entry point '%s' not found in '%s'",
                    entryPointName.c_str(),
                    computeStep->name());
        return nullptr;
    }

    sk_cfp<MTLComputePipelineDescriptor*> psoDescriptor([MTLComputePipelineDescriptor new]);
    (*psoDescriptor).label = @(computeStep->name());
    (*psoDescriptor).computeFunction = computeFunction.get();

    // TODO(b/240604614): Populate input data attribute and buffer layout descriptors using the
    // `stageInputDescriptor` property based on the contents of `pipelineDesc` (on iOS 10+ or
    // macOS 10.12+).

    // TODO(b/240604614): Define input buffer mutability using the `buffers` property based on
    // the contents of `pipelineDesc` (on iOS 11+ or macOS 10.13+).

    // TODO(b/240615224): Metal docs claim that setting the
    // `threadGroupSizeIsMultipleOfThreadExecutionWidth` to YES may improve performance, IF we can
    // guarantee that the thread group size used in a dispatch command is a multiple of
    // `threadExecutionWidth` property of the pipeline state object (otherwise this will cause UB).

    // Explicitly nil-initialized: without ARC a local object pointer is not zero-initialized, and
    // the failure path below reads `error`.
    NSError* error = nil;
    sk_cfp<id<MTLComputePipelineState>> pso([sharedContext->device()
            newComputePipelineStateWithDescriptor:psoDescriptor.get()
                                          options:MTLPipelineOptionNone
                                       reflection:nil
                                            error:&error]);
    if (!pso) {
        const char* reason = error ? error.debugDescription.UTF8String : nullptr;
        SKGPU_LOG_E("Compute pipeline creation failure:\n%s", reason ? reason : "unknown error");
        return nullptr;
    }

    return sk_sp<MtlComputePipeline>(new MtlComputePipeline(sharedContext, std::move(pso)));
}
100
// Releases this object's hold on its Metal compute pipeline state by resetting fPipelineState
// (presumably the sk_cfp handed to the constructor by Make()).
void MtlComputePipeline::freeGpuData() {
    this->fPipelineState.reset();
}
102
103}  // namespace skgpu::graphite
104