1*c8dee2aaSAndroid Build Coastguard Worker /*
2*c8dee2aaSAndroid Build Coastguard Worker * Copyright 2024 Google LLC
3*c8dee2aaSAndroid Build Coastguard Worker *
4*c8dee2aaSAndroid Build Coastguard Worker * Use of this source code is governed by a BSD-style license that can be
5*c8dee2aaSAndroid Build Coastguard Worker * found in the LICENSE file.
6*c8dee2aaSAndroid Build Coastguard Worker */
7*c8dee2aaSAndroid Build Coastguard Worker
8*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/analysis/SkSLSpecialization.h"
9*c8dee2aaSAndroid Build Coastguard Worker
10*c8dee2aaSAndroid Build Coastguard Worker #include "include/private/base/SkAssert.h"
11*c8dee2aaSAndroid Build Coastguard Worker #include "include/private/base/SkSpan_impl.h"
12*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/SkSLAnalysis.h"
13*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/SkSLDefines.h"
14*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/analysis/SkSLProgramVisitor.h"
15*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLExpression.h"
16*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFieldAccess.h"
17*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionCall.h"
18*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionDeclaration.h"
19*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionDefinition.h"
20*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLProgram.h"
21*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLProgramElement.h"
22*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLVariable.h"
23*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLVariableReference.h"
24*c8dee2aaSAndroid Build Coastguard Worker
25*c8dee2aaSAndroid Build Coastguard Worker #include <algorithm>
26*c8dee2aaSAndroid Build Coastguard Worker #include <memory>
27*c8dee2aaSAndroid Build Coastguard Worker
28*c8dee2aaSAndroid Build Coastguard Worker using namespace skia_private;
29*c8dee2aaSAndroid Build Coastguard Worker
30*c8dee2aaSAndroid Build Coastguard Worker namespace SkSL::Analysis {
31*c8dee2aaSAndroid Build Coastguard Worker
parameter_mappings_are_equal(const SpecializedParameters & left,const SpecializedParameters & right)32*c8dee2aaSAndroid Build Coastguard Worker static bool parameter_mappings_are_equal(const SpecializedParameters& left,
33*c8dee2aaSAndroid Build Coastguard Worker const SpecializedParameters& right) {
34*c8dee2aaSAndroid Build Coastguard Worker if (left.count() != right.count()) {
35*c8dee2aaSAndroid Build Coastguard Worker return false;
36*c8dee2aaSAndroid Build Coastguard Worker }
37*c8dee2aaSAndroid Build Coastguard Worker for (const auto& [key, leftExpr] : left) {
38*c8dee2aaSAndroid Build Coastguard Worker const Expression** rightExpr = right.find(key);
39*c8dee2aaSAndroid Build Coastguard Worker if (!rightExpr) {
40*c8dee2aaSAndroid Build Coastguard Worker return false;
41*c8dee2aaSAndroid Build Coastguard Worker }
42*c8dee2aaSAndroid Build Coastguard Worker if (!Analysis::IsSameExpressionTree(*leftExpr, **rightExpr)) {
43*c8dee2aaSAndroid Build Coastguard Worker return false;
44*c8dee2aaSAndroid Build Coastguard Worker }
45*c8dee2aaSAndroid Build Coastguard Worker }
46*c8dee2aaSAndroid Build Coastguard Worker return true;
47*c8dee2aaSAndroid Build Coastguard Worker }
48*c8dee2aaSAndroid Build Coastguard Worker
FindFunctionsToSpecialize(const Program & program,SpecializationInfo * info,const ParameterMatchesFn & parameterMatchesFn)49*c8dee2aaSAndroid Build Coastguard Worker void FindFunctionsToSpecialize(const Program& program,
50*c8dee2aaSAndroid Build Coastguard Worker SpecializationInfo* info,
51*c8dee2aaSAndroid Build Coastguard Worker const ParameterMatchesFn& parameterMatchesFn) {
52*c8dee2aaSAndroid Build Coastguard Worker class Searcher : public ProgramVisitor {
53*c8dee2aaSAndroid Build Coastguard Worker public:
54*c8dee2aaSAndroid Build Coastguard Worker using ProgramVisitor::visitProgramElement;
55*c8dee2aaSAndroid Build Coastguard Worker using INHERITED = ProgramVisitor;
56*c8dee2aaSAndroid Build Coastguard Worker
57*c8dee2aaSAndroid Build Coastguard Worker Searcher(SpecializationInfo& info, const ParameterMatchesFn& parameterMatchesFn)
58*c8dee2aaSAndroid Build Coastguard Worker : fSpecializationMap(info.fSpecializationMap)
59*c8dee2aaSAndroid Build Coastguard Worker , fSpecializedCallMap(info.fSpecializedCallMap)
60*c8dee2aaSAndroid Build Coastguard Worker , fParameterMatchesFn(parameterMatchesFn) {}
61*c8dee2aaSAndroid Build Coastguard Worker
62*c8dee2aaSAndroid Build Coastguard Worker bool visitExpression(const Expression& expr) override {
63*c8dee2aaSAndroid Build Coastguard Worker if (expr.is<FunctionCall>()) {
64*c8dee2aaSAndroid Build Coastguard Worker const FunctionCall& call = expr.as<FunctionCall>();
65*c8dee2aaSAndroid Build Coastguard Worker const FunctionDeclaration& decl = call.function();
66*c8dee2aaSAndroid Build Coastguard Worker
67*c8dee2aaSAndroid Build Coastguard Worker if (!decl.isIntrinsic()) {
68*c8dee2aaSAndroid Build Coastguard Worker SpecializedParameters specialization;
69*c8dee2aaSAndroid Build Coastguard Worker
70*c8dee2aaSAndroid Build Coastguard Worker const int numParams = decl.parameters().size();
71*c8dee2aaSAndroid Build Coastguard Worker SkASSERT(call.arguments().size() == numParams);
72*c8dee2aaSAndroid Build Coastguard Worker
73*c8dee2aaSAndroid Build Coastguard Worker for (int i = 0; i < numParams; i++) {
74*c8dee2aaSAndroid Build Coastguard Worker const Expression& arg = *call.arguments()[i];
75*c8dee2aaSAndroid Build Coastguard Worker
76*c8dee2aaSAndroid Build Coastguard Worker // Specializations can only be made on arguments that are not complex
77*c8dee2aaSAndroid Build Coastguard Worker // expressions but only a variable reference or field access since these
78*c8dee2aaSAndroid Build Coastguard Worker // references will be inlined in the generated specialized functions.
79*c8dee2aaSAndroid Build Coastguard Worker const Variable* argBase = nullptr;
80*c8dee2aaSAndroid Build Coastguard Worker if (arg.is<VariableReference>()) {
81*c8dee2aaSAndroid Build Coastguard Worker argBase = arg.as<VariableReference>().variable();
82*c8dee2aaSAndroid Build Coastguard Worker } else if (arg.is<FieldAccess>() &&
83*c8dee2aaSAndroid Build Coastguard Worker arg.as<FieldAccess>().base()->is<VariableReference>()) {
84*c8dee2aaSAndroid Build Coastguard Worker argBase =
85*c8dee2aaSAndroid Build Coastguard Worker arg.as<FieldAccess>().base()->as<VariableReference>().variable();
86*c8dee2aaSAndroid Build Coastguard Worker } else {
87*c8dee2aaSAndroid Build Coastguard Worker continue;
88*c8dee2aaSAndroid Build Coastguard Worker }
89*c8dee2aaSAndroid Build Coastguard Worker SkASSERT(argBase);
90*c8dee2aaSAndroid Build Coastguard Worker
91*c8dee2aaSAndroid Build Coastguard Worker const Variable* param = decl.parameters()[i];
92*c8dee2aaSAndroid Build Coastguard Worker
93*c8dee2aaSAndroid Build Coastguard Worker // Check that this parameter fits the criteria to create a specialization.
94*c8dee2aaSAndroid Build Coastguard Worker if (!fParameterMatchesFn(*param)) {
95*c8dee2aaSAndroid Build Coastguard Worker continue;
96*c8dee2aaSAndroid Build Coastguard Worker }
97*c8dee2aaSAndroid Build Coastguard Worker
98*c8dee2aaSAndroid Build Coastguard Worker if (argBase->storage() == Variable::Storage::kGlobal) {
99*c8dee2aaSAndroid Build Coastguard Worker specialization[param] = &arg;
100*c8dee2aaSAndroid Build Coastguard Worker } else if (argBase->storage() == Variable::Storage::kParameter) {
101*c8dee2aaSAndroid Build Coastguard Worker const Expression** uniformExpr =
102*c8dee2aaSAndroid Build Coastguard Worker fInheritedSpecializations.find(argBase);
103*c8dee2aaSAndroid Build Coastguard Worker SkASSERT(uniformExpr);
104*c8dee2aaSAndroid Build Coastguard Worker
105*c8dee2aaSAndroid Build Coastguard Worker specialization[param] = *uniformExpr;
106*c8dee2aaSAndroid Build Coastguard Worker } else {
107*c8dee2aaSAndroid Build Coastguard Worker // TODO(b/353532475): Report an error instead of aborting.
108*c8dee2aaSAndroid Build Coastguard Worker SK_ABORT("Specialization requires a uniform or parameter variable");
109*c8dee2aaSAndroid Build Coastguard Worker }
110*c8dee2aaSAndroid Build Coastguard Worker }
111*c8dee2aaSAndroid Build Coastguard Worker
112*c8dee2aaSAndroid Build Coastguard Worker // Only create a specialization for this function if there are
113*c8dee2aaSAndroid Build Coastguard Worker // variables to specialize on.
114*c8dee2aaSAndroid Build Coastguard Worker if (specialization.count() > 0) {
115*c8dee2aaSAndroid Build Coastguard Worker Specializations& specializations = fSpecializationMap[&decl];
116*c8dee2aaSAndroid Build Coastguard Worker SpecializedCallKey callKey{call.stablePointer(),
117*c8dee2aaSAndroid Build Coastguard Worker fInheritedSpecializationIndex};
118*c8dee2aaSAndroid Build Coastguard Worker
119*c8dee2aaSAndroid Build Coastguard Worker for (int i = 0; i < specializations.size(); i++) {
120*c8dee2aaSAndroid Build Coastguard Worker const SpecializedParameters& entry = specializations[i];
121*c8dee2aaSAndroid Build Coastguard Worker if (parameter_mappings_are_equal(specialization, entry)) {
122*c8dee2aaSAndroid Build Coastguard Worker // This specialization has already been tracked.
123*c8dee2aaSAndroid Build Coastguard Worker fSpecializedCallMap[callKey] = i;
124*c8dee2aaSAndroid Build Coastguard Worker return INHERITED::visitExpression(expr);
125*c8dee2aaSAndroid Build Coastguard Worker }
126*c8dee2aaSAndroid Build Coastguard Worker }
127*c8dee2aaSAndroid Build Coastguard Worker
128*c8dee2aaSAndroid Build Coastguard Worker // Set the index to the corresponding specialization this function call
129*c8dee2aaSAndroid Build Coastguard Worker // requires, also tracking the inherited specialization this function
130*c8dee2aaSAndroid Build Coastguard Worker // call is in so the right specialized function can be called.
131*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex specializationIndex = specializations.size();
132*c8dee2aaSAndroid Build Coastguard Worker fSpecializedCallMap[callKey] = specializationIndex;
133*c8dee2aaSAndroid Build Coastguard Worker specializations.push_back(specialization);
134*c8dee2aaSAndroid Build Coastguard Worker
135*c8dee2aaSAndroid Build Coastguard Worker // We swap so we don't lose when our last inherited specializations were
136*c8dee2aaSAndroid Build Coastguard Worker // once we are done traversing this specific specialization.
137*c8dee2aaSAndroid Build Coastguard Worker fInheritedSpecializations.swap(specialization);
138*c8dee2aaSAndroid Build Coastguard Worker std::swap(fInheritedSpecializationIndex, specializationIndex);
139*c8dee2aaSAndroid Build Coastguard Worker
140*c8dee2aaSAndroid Build Coastguard Worker this->visitProgramElement(*decl.definition());
141*c8dee2aaSAndroid Build Coastguard Worker
142*c8dee2aaSAndroid Build Coastguard Worker std::swap(fInheritedSpecializationIndex, specializationIndex);
143*c8dee2aaSAndroid Build Coastguard Worker fInheritedSpecializations.swap(specialization);
144*c8dee2aaSAndroid Build Coastguard Worker } else {
145*c8dee2aaSAndroid Build Coastguard Worker // The function being called isn't specialized, but we need to walk the
146*c8dee2aaSAndroid Build Coastguard Worker // entire call graph or we may miss a specialized call entirely. Since
147*c8dee2aaSAndroid Build Coastguard Worker // nothing is specialized, it is safe to skip over repeated traversals.
148*c8dee2aaSAndroid Build Coastguard Worker if (!fVisitedFunctions.find(&decl)) {
149*c8dee2aaSAndroid Build Coastguard Worker fVisitedFunctions.add(&decl);
150*c8dee2aaSAndroid Build Coastguard Worker this->visitProgramElement(*decl.definition());
151*c8dee2aaSAndroid Build Coastguard Worker }
152*c8dee2aaSAndroid Build Coastguard Worker }
153*c8dee2aaSAndroid Build Coastguard Worker }
154*c8dee2aaSAndroid Build Coastguard Worker }
155*c8dee2aaSAndroid Build Coastguard Worker return INHERITED::visitExpression(expr);
156*c8dee2aaSAndroid Build Coastguard Worker }
157*c8dee2aaSAndroid Build Coastguard Worker
158*c8dee2aaSAndroid Build Coastguard Worker private:
159*c8dee2aaSAndroid Build Coastguard Worker SpecializationMap& fSpecializationMap;
160*c8dee2aaSAndroid Build Coastguard Worker SpecializedCallMap& fSpecializedCallMap;
161*c8dee2aaSAndroid Build Coastguard Worker const ParameterMatchesFn& fParameterMatchesFn;
162*c8dee2aaSAndroid Build Coastguard Worker THashSet<const FunctionDeclaration*> fVisitedFunctions;
163*c8dee2aaSAndroid Build Coastguard Worker
164*c8dee2aaSAndroid Build Coastguard Worker SpecializedParameters fInheritedSpecializations;
165*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex fInheritedSpecializationIndex = kUnspecialized;
166*c8dee2aaSAndroid Build Coastguard Worker };
167*c8dee2aaSAndroid Build Coastguard Worker
168*c8dee2aaSAndroid Build Coastguard Worker for (const ProgramElement* elem : program.elements()) {
169*c8dee2aaSAndroid Build Coastguard Worker if (elem->is<FunctionDefinition>()) {
170*c8dee2aaSAndroid Build Coastguard Worker const FunctionDeclaration& decl = elem->as<FunctionDefinition>().declaration();
171*c8dee2aaSAndroid Build Coastguard Worker
172*c8dee2aaSAndroid Build Coastguard Worker if (decl.isMain()) {
173*c8dee2aaSAndroid Build Coastguard Worker // Visit through the program call stack and aggregates any necessary
174*c8dee2aaSAndroid Build Coastguard Worker // function specializations.
175*c8dee2aaSAndroid Build Coastguard Worker Searcher(*info, parameterMatchesFn).visitProgramElement(*elem);
176*c8dee2aaSAndroid Build Coastguard Worker continue;
177*c8dee2aaSAndroid Build Coastguard Worker }
178*c8dee2aaSAndroid Build Coastguard Worker
179*c8dee2aaSAndroid Build Coastguard Worker // Look for any function parameter which needs specialization.
180*c8dee2aaSAndroid Build Coastguard Worker for (const Variable* param : decl.parameters()) {
181*c8dee2aaSAndroid Build Coastguard Worker if (parameterMatchesFn(*param)) {
182*c8dee2aaSAndroid Build Coastguard Worker // We found a function that requires specialization. Ensure that this function
183*c8dee2aaSAndroid Build Coastguard Worker // ends up in the specialization map, whether or not it is reachable from
184*c8dee2aaSAndroid Build Coastguard Worker // main().
185*c8dee2aaSAndroid Build Coastguard Worker //
186*c8dee2aaSAndroid Build Coastguard Worker // Doing this here allows unreachable specialized functions to be discarded,
187*c8dee2aaSAndroid Build Coastguard Worker // because it will be in the specialization map with an array of zero necessary
188*c8dee2aaSAndroid Build Coastguard Worker // specializations to emit. If we didn't add this function to the specialization
189*c8dee2aaSAndroid Build Coastguard Worker // map at all, the code generator would try to emit it without applying
190*c8dee2aaSAndroid Build Coastguard Worker // specializations, and generally this would lead to invalid code.
191*c8dee2aaSAndroid Build Coastguard Worker info->fSpecializationMap[&decl];
192*c8dee2aaSAndroid Build Coastguard Worker break;
193*c8dee2aaSAndroid Build Coastguard Worker }
194*c8dee2aaSAndroid Build Coastguard Worker }
195*c8dee2aaSAndroid Build Coastguard Worker }
196*c8dee2aaSAndroid Build Coastguard Worker }
197*c8dee2aaSAndroid Build Coastguard Worker }
198*c8dee2aaSAndroid Build Coastguard Worker
FindSpecializationIndexForCall(const FunctionCall & call,const SpecializationInfo & info,SpecializationIndex parentSpecializationIndex)199*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
200*c8dee2aaSAndroid Build Coastguard Worker const SpecializationInfo& info,
201*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex parentSpecializationIndex) {
202*c8dee2aaSAndroid Build Coastguard Worker SpecializedCallKey callKey{call.stablePointer(), parentSpecializationIndex};
203*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex* foundIndex = info.fSpecializedCallMap.find(callKey);
204*c8dee2aaSAndroid Build Coastguard Worker return foundIndex ? *foundIndex : kUnspecialized;
205*c8dee2aaSAndroid Build Coastguard Worker }
206*c8dee2aaSAndroid Build Coastguard Worker
FindSpecializedParametersForFunction(const FunctionDeclaration & func,const SpecializationInfo & info)207*c8dee2aaSAndroid Build Coastguard Worker SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
208*c8dee2aaSAndroid Build Coastguard Worker const SpecializationInfo& info) {
209*c8dee2aaSAndroid Build Coastguard Worker SkBitSet result(func.parameters().size());
210*c8dee2aaSAndroid Build Coastguard Worker if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
211*c8dee2aaSAndroid Build Coastguard Worker const Analysis::SpecializedParameters& specializedParams = specializations->front();
212*c8dee2aaSAndroid Build Coastguard Worker const SkSpan<Variable* const> funcParams = func.parameters();
213*c8dee2aaSAndroid Build Coastguard Worker
214*c8dee2aaSAndroid Build Coastguard Worker for (size_t index = 0; index < funcParams.size(); ++index) {
215*c8dee2aaSAndroid Build Coastguard Worker if (specializedParams.find(funcParams[index])) {
216*c8dee2aaSAndroid Build Coastguard Worker result.set(index);
217*c8dee2aaSAndroid Build Coastguard Worker }
218*c8dee2aaSAndroid Build Coastguard Worker }
219*c8dee2aaSAndroid Build Coastguard Worker }
220*c8dee2aaSAndroid Build Coastguard Worker
221*c8dee2aaSAndroid Build Coastguard Worker return result;
222*c8dee2aaSAndroid Build Coastguard Worker }
223*c8dee2aaSAndroid Build Coastguard Worker
GetParameterMappingsForFunction(const FunctionDeclaration & func,const SpecializationInfo & info,SpecializationIndex specializationIndex,const ParameterMappingCallback & callback)224*c8dee2aaSAndroid Build Coastguard Worker void GetParameterMappingsForFunction(const FunctionDeclaration& func,
225*c8dee2aaSAndroid Build Coastguard Worker const SpecializationInfo& info,
226*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex specializationIndex,
227*c8dee2aaSAndroid Build Coastguard Worker const ParameterMappingCallback& callback) {
228*c8dee2aaSAndroid Build Coastguard Worker if (specializationIndex != Analysis::kUnspecialized) {
229*c8dee2aaSAndroid Build Coastguard Worker if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
230*c8dee2aaSAndroid Build Coastguard Worker const Analysis::SpecializedParameters& specializedParams =
231*c8dee2aaSAndroid Build Coastguard Worker specializations->at(specializationIndex);
232*c8dee2aaSAndroid Build Coastguard Worker const SkSpan<Variable* const> funcParams = func.parameters();
233*c8dee2aaSAndroid Build Coastguard Worker
234*c8dee2aaSAndroid Build Coastguard Worker for (size_t index = 0; index < funcParams.size(); ++index) {
235*c8dee2aaSAndroid Build Coastguard Worker const Variable* param = funcParams[index];
236*c8dee2aaSAndroid Build Coastguard Worker if (const Expression** expr = specializedParams.find(param)) {
237*c8dee2aaSAndroid Build Coastguard Worker callback(index, param, *expr);
238*c8dee2aaSAndroid Build Coastguard Worker }
239*c8dee2aaSAndroid Build Coastguard Worker }
240*c8dee2aaSAndroid Build Coastguard Worker }
241*c8dee2aaSAndroid Build Coastguard Worker }
242*c8dee2aaSAndroid Build Coastguard Worker }
243*c8dee2aaSAndroid Build Coastguard Worker
244*c8dee2aaSAndroid Build Coastguard Worker } // namespace SkSL::Analysis
245