xref: /aosp_15_r20/external/skia/src/gpu/graphite/mtl/MtlComputePipeline.mm (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1*c8dee2aaSAndroid Build Coastguard Worker/*
2*c8dee2aaSAndroid Build Coastguard Worker * Copyright 2022 Google LLC
3*c8dee2aaSAndroid Build Coastguard Worker *
4*c8dee2aaSAndroid Build Coastguard Worker * Use of this source code is governed by a BSD-style license that can be
5*c8dee2aaSAndroid Build Coastguard Worker * found in the LICENSE file.
6*c8dee2aaSAndroid Build Coastguard Worker */
7*c8dee2aaSAndroid Build Coastguard Worker
8*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/mtl/MtlComputePipeline.h"
9*c8dee2aaSAndroid Build Coastguard Worker
10*c8dee2aaSAndroid Build Coastguard Worker#include "include/gpu/ShaderErrorHandler.h"
11*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/ComputePipelineDesc.h"
12*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/ContextUtils.h"
13*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/Log.h"
14*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/ResourceProvider.h"
15*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/mtl/MtlGraphiteUtilsPriv.h"
16*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/graphite/mtl/MtlSharedContext.h"
17*c8dee2aaSAndroid Build Coastguard Worker#include "src/gpu/mtl/MtlUtilsPriv.h"
18*c8dee2aaSAndroid Build Coastguard Worker#include "src/sksl/SkSLCompiler.h"
19*c8dee2aaSAndroid Build Coastguard Worker#include "src/sksl/SkSLProgramKind.h"
20*c8dee2aaSAndroid Build Coastguard Worker#include "src/sksl/SkSLProgramSettings.h"
21*c8dee2aaSAndroid Build Coastguard Worker#include "src/sksl/ir/SkSLProgram.h"
22*c8dee2aaSAndroid Build Coastguard Worker
namespace skgpu::graphite {

// static
//
// Builds a Metal compute pipeline for the ComputeStep referenced by `pipelineDesc`.
//
// Shader source comes from one of two places:
//  * the step's native MSL, when `supportsNativeShader()` is true, using the entry point the
//    step provides; or
//  * SkSL produced by BuildComputeSkSL(), translated to MSL by SkSLToMSL(), using the entry
//    point "computeMain".
// The MSL is compiled into an MTLLibrary, the entry-point function is looked up in it, and a
// MTLComputePipelineState is created from a descriptor labelled with the step's name.
//
// The caps' ShaderErrorHandler is passed to both the SkSL->MSL translation and the MSL
// compilation helpers. Entry-point lookup and pipeline-state failures are logged with
// SKGPU_LOG_E.
//
// Returns nullptr if any stage fails.
sk_sp<MtlComputePipeline> MtlComputePipeline::Make(const MtlSharedContext* sharedContext,
                                                   const ComputePipelineDesc& pipelineDesc) {
    const ComputeStep* step = pipelineDesc.computeStep();
    ShaderErrorHandler* errorHandler = sharedContext->caps()->shaderErrorHandler();

    // Each branch only decides *what* MSL to compile and which function to use as the entry
    // point. Compilation happens once, below.
    std::string msl;
    std::string entryPointName;
    if (step->supportsNativeShader()) {
        auto nativeShader = step->nativeShaderSource(ComputeStep::NativeShaderFormat::kMSL);
        msl = nativeShader.fSource;
        entryPointName = std::move(nativeShader.fEntryPoint);
    } else {
        SkSL::Program::Interface interface;
        SkSL::ProgramSettings settings;
        std::string sksl = BuildComputeSkSL(sharedContext->caps(), step);
        if (!SkSLToMSL(sharedContext->caps()->shaderCaps(),
                       sksl,
                       SkSL::ProgramKind::kCompute,
                       settings,
                       &msl,
                       &interface,
                       errorHandler)) {
            return nullptr;
        }
        entryPointName = "computeMain";
    }

    sk_cfp<id<MTLLibrary>> library;
    library = MtlCompileShaderLibrary(sharedContext, step->name(), msl, errorHandler);
    if (library == nil) {
        return nullptr;
    }

    sk_cfp<MTLComputePipelineDescriptor*> psoDescriptor([MTLComputePipelineDescriptor new]);

    (*psoDescriptor).label = @(step->name());

    // -newFunctionWithName: returns an owned (+1) reference. This file manages Objective-C
    // ownership manually through sk_cfp (see the `new`-returning calls wrapped below/above), so
    // adopt the reference here. It is released when `computeFunction` goes out of scope;
    // the descriptor's `computeFunction` property keeps its own reference.
    NSString* entryPoint = [NSString stringWithUTF8String:entryPointName.c_str()];
    sk_cfp<id<MTLFunction>> computeFunction([library.get() newFunctionWithName:entryPoint]);
    if (computeFunction == nil) {
        // Catch a missing entry point here with a clear message, instead of handing a nil
        // function to pipeline-state creation.
        SKGPU_LOG_E("Compute pipeline creation failure: entry point '%s' not found for '%s'",
                    entryPointName.c_str(),
                    step->name());
        return nullptr;
    }
    (*psoDescriptor).computeFunction = computeFunction.get();

    // TODO(b/240604614): Populate input data attribute and buffer layout descriptors using the
    // `stageInputDescriptor` property based on the contents of `pipelineDesc` (on iOS 10+ or
    // macOS 10.12+).

    // TODO(b/240604614): Define input buffer mutability using the `buffers` property based on
    // the contents of `pipelineDesc` (on iOS 11+ or macOS 10.13+).

    // TODO(b/240615224): Metal docs claim that setting the
    // `threadGroupSizeIsMultipleOfThreadExecutionWidth` to YES may improve performance, IF we can
    // guarantee that the thread group size used in a dispatch command is a multiple of
    // `threadExecutionWidth` property of the pipeline state object (otherwise this will cause UB).

    // Explicitly nil: without ARC, object-pointer locals are not zero-initialized. `error` is
    // read below even if the device does not populate it.
    NSError* error = nil;
    sk_cfp<id<MTLComputePipelineState>> pso([sharedContext->device()
            newComputePipelineStateWithDescriptor:psoDescriptor.get()
                                          options:MTLPipelineOptionNone
                                       reflection:nil
                                            error:&error]);
    if (!pso) {
        SKGPU_LOG_E("Compute pipeline creation failure:\n%s",
                    error ? error.debugDescription.UTF8String : "(no error details)");
        return nullptr;
    }

    return sk_sp<MtlComputePipeline>(new MtlComputePipeline(sharedContext, std::move(pso)));
}

// Resets fPipelineState, dropping this object's reference to its Metal pipeline state
// (presumably the `pso` handed to the constructor in Make() — constructor not visible here).
void MtlComputePipeline::freeGpuData() { fPipelineState.reset(); }

}  // namespace skgpu::graphite
104