xref: /aosp_15_r20/external/skia/src/sksl/analysis/SkSLSpecialization.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1*c8dee2aaSAndroid Build Coastguard Worker /*
2*c8dee2aaSAndroid Build Coastguard Worker  * Copyright 2024 Google LLC
3*c8dee2aaSAndroid Build Coastguard Worker  *
4*c8dee2aaSAndroid Build Coastguard Worker  * Use of this source code is governed by a BSD-style license that can be
5*c8dee2aaSAndroid Build Coastguard Worker  * found in the LICENSE file.
6*c8dee2aaSAndroid Build Coastguard Worker  */
7*c8dee2aaSAndroid Build Coastguard Worker 
8*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/analysis/SkSLSpecialization.h"
9*c8dee2aaSAndroid Build Coastguard Worker 
10*c8dee2aaSAndroid Build Coastguard Worker #include "include/private/base/SkAssert.h"
11*c8dee2aaSAndroid Build Coastguard Worker #include "include/private/base/SkSpan_impl.h"
12*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/SkSLAnalysis.h"
13*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/SkSLDefines.h"
14*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/analysis/SkSLProgramVisitor.h"
15*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLExpression.h"
16*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFieldAccess.h"
17*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionCall.h"
18*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionDeclaration.h"
19*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLFunctionDefinition.h"
20*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLProgram.h"
21*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLProgramElement.h"
22*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLVariable.h"
23*c8dee2aaSAndroid Build Coastguard Worker #include "src/sksl/ir/SkSLVariableReference.h"
24*c8dee2aaSAndroid Build Coastguard Worker 
25*c8dee2aaSAndroid Build Coastguard Worker #include <algorithm>
26*c8dee2aaSAndroid Build Coastguard Worker #include <memory>
27*c8dee2aaSAndroid Build Coastguard Worker 
28*c8dee2aaSAndroid Build Coastguard Worker using namespace skia_private;
29*c8dee2aaSAndroid Build Coastguard Worker 
30*c8dee2aaSAndroid Build Coastguard Worker namespace SkSL::Analysis {
31*c8dee2aaSAndroid Build Coastguard Worker 
parameter_mappings_are_equal(const SpecializedParameters & left,const SpecializedParameters & right)32*c8dee2aaSAndroid Build Coastguard Worker static bool parameter_mappings_are_equal(const SpecializedParameters& left,
33*c8dee2aaSAndroid Build Coastguard Worker                                          const SpecializedParameters& right) {
34*c8dee2aaSAndroid Build Coastguard Worker     if (left.count() != right.count()) {
35*c8dee2aaSAndroid Build Coastguard Worker         return false;
36*c8dee2aaSAndroid Build Coastguard Worker     }
37*c8dee2aaSAndroid Build Coastguard Worker     for (const auto& [key, leftExpr] : left) {
38*c8dee2aaSAndroid Build Coastguard Worker         const Expression** rightExpr = right.find(key);
39*c8dee2aaSAndroid Build Coastguard Worker         if (!rightExpr) {
40*c8dee2aaSAndroid Build Coastguard Worker             return false;
41*c8dee2aaSAndroid Build Coastguard Worker         }
42*c8dee2aaSAndroid Build Coastguard Worker         if (!Analysis::IsSameExpressionTree(*leftExpr, **rightExpr)) {
43*c8dee2aaSAndroid Build Coastguard Worker             return false;
44*c8dee2aaSAndroid Build Coastguard Worker         }
45*c8dee2aaSAndroid Build Coastguard Worker     }
46*c8dee2aaSAndroid Build Coastguard Worker     return true;
47*c8dee2aaSAndroid Build Coastguard Worker }
48*c8dee2aaSAndroid Build Coastguard Worker 
FindFunctionsToSpecialize(const Program & program,SpecializationInfo * info,const ParameterMatchesFn & parameterMatchesFn)49*c8dee2aaSAndroid Build Coastguard Worker void FindFunctionsToSpecialize(const Program& program,
50*c8dee2aaSAndroid Build Coastguard Worker                                SpecializationInfo* info,
51*c8dee2aaSAndroid Build Coastguard Worker                                const ParameterMatchesFn& parameterMatchesFn) {
52*c8dee2aaSAndroid Build Coastguard Worker     class Searcher : public ProgramVisitor {
53*c8dee2aaSAndroid Build Coastguard Worker     public:
54*c8dee2aaSAndroid Build Coastguard Worker         using ProgramVisitor::visitProgramElement;
55*c8dee2aaSAndroid Build Coastguard Worker         using INHERITED = ProgramVisitor;
56*c8dee2aaSAndroid Build Coastguard Worker 
57*c8dee2aaSAndroid Build Coastguard Worker         Searcher(SpecializationInfo& info, const ParameterMatchesFn& parameterMatchesFn)
58*c8dee2aaSAndroid Build Coastguard Worker                 : fSpecializationMap(info.fSpecializationMap)
59*c8dee2aaSAndroid Build Coastguard Worker                 , fSpecializedCallMap(info.fSpecializedCallMap)
60*c8dee2aaSAndroid Build Coastguard Worker                 , fParameterMatchesFn(parameterMatchesFn) {}
61*c8dee2aaSAndroid Build Coastguard Worker 
62*c8dee2aaSAndroid Build Coastguard Worker         bool visitExpression(const Expression& expr) override {
63*c8dee2aaSAndroid Build Coastguard Worker             if (expr.is<FunctionCall>()) {
64*c8dee2aaSAndroid Build Coastguard Worker                 const FunctionCall& call = expr.as<FunctionCall>();
65*c8dee2aaSAndroid Build Coastguard Worker                 const FunctionDeclaration& decl = call.function();
66*c8dee2aaSAndroid Build Coastguard Worker 
67*c8dee2aaSAndroid Build Coastguard Worker                 if (!decl.isIntrinsic()) {
68*c8dee2aaSAndroid Build Coastguard Worker                     SpecializedParameters specialization;
69*c8dee2aaSAndroid Build Coastguard Worker 
70*c8dee2aaSAndroid Build Coastguard Worker                     const int numParams = decl.parameters().size();
71*c8dee2aaSAndroid Build Coastguard Worker                     SkASSERT(call.arguments().size() == numParams);
72*c8dee2aaSAndroid Build Coastguard Worker 
73*c8dee2aaSAndroid Build Coastguard Worker                     for (int i = 0; i < numParams; i++) {
74*c8dee2aaSAndroid Build Coastguard Worker                         const Expression& arg = *call.arguments()[i];
75*c8dee2aaSAndroid Build Coastguard Worker 
76*c8dee2aaSAndroid Build Coastguard Worker                         // Specializations can only be made on arguments that are not complex
77*c8dee2aaSAndroid Build Coastguard Worker                         // expressions but only a variable reference or field access since these
78*c8dee2aaSAndroid Build Coastguard Worker                         // references will be inlined in the generated specialized functions.
79*c8dee2aaSAndroid Build Coastguard Worker                         const Variable* argBase = nullptr;
80*c8dee2aaSAndroid Build Coastguard Worker                         if (arg.is<VariableReference>()) {
81*c8dee2aaSAndroid Build Coastguard Worker                             argBase = arg.as<VariableReference>().variable();
82*c8dee2aaSAndroid Build Coastguard Worker                         } else if (arg.is<FieldAccess>() &&
83*c8dee2aaSAndroid Build Coastguard Worker                                    arg.as<FieldAccess>().base()->is<VariableReference>()) {
84*c8dee2aaSAndroid Build Coastguard Worker                             argBase =
85*c8dee2aaSAndroid Build Coastguard Worker                                 arg.as<FieldAccess>().base()->as<VariableReference>().variable();
86*c8dee2aaSAndroid Build Coastguard Worker                         } else {
87*c8dee2aaSAndroid Build Coastguard Worker                             continue;
88*c8dee2aaSAndroid Build Coastguard Worker                         }
89*c8dee2aaSAndroid Build Coastguard Worker                         SkASSERT(argBase);
90*c8dee2aaSAndroid Build Coastguard Worker 
91*c8dee2aaSAndroid Build Coastguard Worker                         const Variable* param = decl.parameters()[i];
92*c8dee2aaSAndroid Build Coastguard Worker 
93*c8dee2aaSAndroid Build Coastguard Worker                         // Check that this parameter fits the criteria to create a specialization.
94*c8dee2aaSAndroid Build Coastguard Worker                         if (!fParameterMatchesFn(*param)) {
95*c8dee2aaSAndroid Build Coastguard Worker                             continue;
96*c8dee2aaSAndroid Build Coastguard Worker                         }
97*c8dee2aaSAndroid Build Coastguard Worker 
98*c8dee2aaSAndroid Build Coastguard Worker                         if (argBase->storage() == Variable::Storage::kGlobal) {
99*c8dee2aaSAndroid Build Coastguard Worker                             specialization[param] = &arg;
100*c8dee2aaSAndroid Build Coastguard Worker                         } else if (argBase->storage() == Variable::Storage::kParameter) {
101*c8dee2aaSAndroid Build Coastguard Worker                             const Expression** uniformExpr =
102*c8dee2aaSAndroid Build Coastguard Worker                                 fInheritedSpecializations.find(argBase);
103*c8dee2aaSAndroid Build Coastguard Worker                             SkASSERT(uniformExpr);
104*c8dee2aaSAndroid Build Coastguard Worker 
105*c8dee2aaSAndroid Build Coastguard Worker                             specialization[param] = *uniformExpr;
106*c8dee2aaSAndroid Build Coastguard Worker                         } else {
107*c8dee2aaSAndroid Build Coastguard Worker                             // TODO(b/353532475): Report an error instead of aborting.
108*c8dee2aaSAndroid Build Coastguard Worker                             SK_ABORT("Specialization requires a uniform or parameter variable");
109*c8dee2aaSAndroid Build Coastguard Worker                         }
110*c8dee2aaSAndroid Build Coastguard Worker                     }
111*c8dee2aaSAndroid Build Coastguard Worker 
112*c8dee2aaSAndroid Build Coastguard Worker                     // Only create a specialization for this function if there are
113*c8dee2aaSAndroid Build Coastguard Worker                     // variables to specialize on.
114*c8dee2aaSAndroid Build Coastguard Worker                     if (specialization.count() > 0) {
115*c8dee2aaSAndroid Build Coastguard Worker                         Specializations& specializations = fSpecializationMap[&decl];
116*c8dee2aaSAndroid Build Coastguard Worker                         SpecializedCallKey callKey{call.stablePointer(),
117*c8dee2aaSAndroid Build Coastguard Worker                                                    fInheritedSpecializationIndex};
118*c8dee2aaSAndroid Build Coastguard Worker 
119*c8dee2aaSAndroid Build Coastguard Worker                         for (int i = 0; i < specializations.size(); i++) {
120*c8dee2aaSAndroid Build Coastguard Worker                             const SpecializedParameters& entry = specializations[i];
121*c8dee2aaSAndroid Build Coastguard Worker                             if (parameter_mappings_are_equal(specialization, entry)) {
122*c8dee2aaSAndroid Build Coastguard Worker                                 // This specialization has already been tracked.
123*c8dee2aaSAndroid Build Coastguard Worker                                 fSpecializedCallMap[callKey] = i;
124*c8dee2aaSAndroid Build Coastguard Worker                                 return INHERITED::visitExpression(expr);
125*c8dee2aaSAndroid Build Coastguard Worker                             }
126*c8dee2aaSAndroid Build Coastguard Worker                         }
127*c8dee2aaSAndroid Build Coastguard Worker 
128*c8dee2aaSAndroid Build Coastguard Worker                         // Set the index to the corresponding specialization this function call
129*c8dee2aaSAndroid Build Coastguard Worker                         // requires, also tracking the inherited specialization this function
130*c8dee2aaSAndroid Build Coastguard Worker                         // call is in so the right specialized function can be called.
131*c8dee2aaSAndroid Build Coastguard Worker                         SpecializationIndex specializationIndex = specializations.size();
132*c8dee2aaSAndroid Build Coastguard Worker                         fSpecializedCallMap[callKey] = specializationIndex;
133*c8dee2aaSAndroid Build Coastguard Worker                         specializations.push_back(specialization);
134*c8dee2aaSAndroid Build Coastguard Worker 
135*c8dee2aaSAndroid Build Coastguard Worker                         // We swap so we don't lose when our last inherited specializations were
136*c8dee2aaSAndroid Build Coastguard Worker                         // once we are done traversing this specific specialization.
137*c8dee2aaSAndroid Build Coastguard Worker                         fInheritedSpecializations.swap(specialization);
138*c8dee2aaSAndroid Build Coastguard Worker                         std::swap(fInheritedSpecializationIndex, specializationIndex);
139*c8dee2aaSAndroid Build Coastguard Worker 
140*c8dee2aaSAndroid Build Coastguard Worker                         this->visitProgramElement(*decl.definition());
141*c8dee2aaSAndroid Build Coastguard Worker 
142*c8dee2aaSAndroid Build Coastguard Worker                         std::swap(fInheritedSpecializationIndex, specializationIndex);
143*c8dee2aaSAndroid Build Coastguard Worker                         fInheritedSpecializations.swap(specialization);
144*c8dee2aaSAndroid Build Coastguard Worker                     } else {
145*c8dee2aaSAndroid Build Coastguard Worker                         // The function being called isn't specialized, but we need to walk the
146*c8dee2aaSAndroid Build Coastguard Worker                         // entire call graph or we may miss a specialized call entirely. Since
147*c8dee2aaSAndroid Build Coastguard Worker                         // nothing is specialized, it is safe to skip over repeated traversals.
148*c8dee2aaSAndroid Build Coastguard Worker                         if (!fVisitedFunctions.find(&decl)) {
149*c8dee2aaSAndroid Build Coastguard Worker                             fVisitedFunctions.add(&decl);
150*c8dee2aaSAndroid Build Coastguard Worker                             this->visitProgramElement(*decl.definition());
151*c8dee2aaSAndroid Build Coastguard Worker                         }
152*c8dee2aaSAndroid Build Coastguard Worker                     }
153*c8dee2aaSAndroid Build Coastguard Worker                 }
154*c8dee2aaSAndroid Build Coastguard Worker             }
155*c8dee2aaSAndroid Build Coastguard Worker             return INHERITED::visitExpression(expr);
156*c8dee2aaSAndroid Build Coastguard Worker         }
157*c8dee2aaSAndroid Build Coastguard Worker 
158*c8dee2aaSAndroid Build Coastguard Worker     private:
159*c8dee2aaSAndroid Build Coastguard Worker         SpecializationMap& fSpecializationMap;
160*c8dee2aaSAndroid Build Coastguard Worker         SpecializedCallMap& fSpecializedCallMap;
161*c8dee2aaSAndroid Build Coastguard Worker         const ParameterMatchesFn& fParameterMatchesFn;
162*c8dee2aaSAndroid Build Coastguard Worker         THashSet<const FunctionDeclaration*> fVisitedFunctions;
163*c8dee2aaSAndroid Build Coastguard Worker 
164*c8dee2aaSAndroid Build Coastguard Worker         SpecializedParameters fInheritedSpecializations;
165*c8dee2aaSAndroid Build Coastguard Worker         SpecializationIndex fInheritedSpecializationIndex = kUnspecialized;
166*c8dee2aaSAndroid Build Coastguard Worker     };
167*c8dee2aaSAndroid Build Coastguard Worker 
168*c8dee2aaSAndroid Build Coastguard Worker     for (const ProgramElement* elem : program.elements()) {
169*c8dee2aaSAndroid Build Coastguard Worker         if (elem->is<FunctionDefinition>()) {
170*c8dee2aaSAndroid Build Coastguard Worker             const FunctionDeclaration& decl = elem->as<FunctionDefinition>().declaration();
171*c8dee2aaSAndroid Build Coastguard Worker 
172*c8dee2aaSAndroid Build Coastguard Worker             if (decl.isMain()) {
173*c8dee2aaSAndroid Build Coastguard Worker                 // Visit through the program call stack and aggregates any necessary
174*c8dee2aaSAndroid Build Coastguard Worker                 // function specializations.
175*c8dee2aaSAndroid Build Coastguard Worker                 Searcher(*info, parameterMatchesFn).visitProgramElement(*elem);
176*c8dee2aaSAndroid Build Coastguard Worker                 continue;
177*c8dee2aaSAndroid Build Coastguard Worker             }
178*c8dee2aaSAndroid Build Coastguard Worker 
179*c8dee2aaSAndroid Build Coastguard Worker             // Look for any function parameter which needs specialization.
180*c8dee2aaSAndroid Build Coastguard Worker             for (const Variable* param : decl.parameters()) {
181*c8dee2aaSAndroid Build Coastguard Worker                 if (parameterMatchesFn(*param)) {
182*c8dee2aaSAndroid Build Coastguard Worker                     // We found a function that requires specialization. Ensure that this function
183*c8dee2aaSAndroid Build Coastguard Worker                     // ends up in the specialization map, whether or not it is reachable from
184*c8dee2aaSAndroid Build Coastguard Worker                     // main().
185*c8dee2aaSAndroid Build Coastguard Worker                     //
186*c8dee2aaSAndroid Build Coastguard Worker                     // Doing this here allows unreachable specialized functions to be discarded,
187*c8dee2aaSAndroid Build Coastguard Worker                     // because it will be in the specialization map with an array of zero necessary
188*c8dee2aaSAndroid Build Coastguard Worker                     // specializations to emit. If we didn't add this function to the specialization
189*c8dee2aaSAndroid Build Coastguard Worker                     // map at all, the code generator would try to emit it without applying
190*c8dee2aaSAndroid Build Coastguard Worker                     // specializations, and generally this would lead to invalid code.
191*c8dee2aaSAndroid Build Coastguard Worker                     info->fSpecializationMap[&decl];
192*c8dee2aaSAndroid Build Coastguard Worker                     break;
193*c8dee2aaSAndroid Build Coastguard Worker                 }
194*c8dee2aaSAndroid Build Coastguard Worker             }
195*c8dee2aaSAndroid Build Coastguard Worker         }
196*c8dee2aaSAndroid Build Coastguard Worker     }
197*c8dee2aaSAndroid Build Coastguard Worker }
198*c8dee2aaSAndroid Build Coastguard Worker 
FindSpecializationIndexForCall(const FunctionCall & call,const SpecializationInfo & info,SpecializationIndex parentSpecializationIndex)199*c8dee2aaSAndroid Build Coastguard Worker SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
200*c8dee2aaSAndroid Build Coastguard Worker                                                    const SpecializationInfo& info,
201*c8dee2aaSAndroid Build Coastguard Worker                                                    SpecializationIndex parentSpecializationIndex) {
202*c8dee2aaSAndroid Build Coastguard Worker     SpecializedCallKey callKey{call.stablePointer(), parentSpecializationIndex};
203*c8dee2aaSAndroid Build Coastguard Worker     SpecializationIndex* foundIndex = info.fSpecializedCallMap.find(callKey);
204*c8dee2aaSAndroid Build Coastguard Worker     return foundIndex ? *foundIndex : kUnspecialized;
205*c8dee2aaSAndroid Build Coastguard Worker }
206*c8dee2aaSAndroid Build Coastguard Worker 
FindSpecializedParametersForFunction(const FunctionDeclaration & func,const SpecializationInfo & info)207*c8dee2aaSAndroid Build Coastguard Worker SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
208*c8dee2aaSAndroid Build Coastguard Worker                                               const SpecializationInfo& info) {
209*c8dee2aaSAndroid Build Coastguard Worker     SkBitSet result(func.parameters().size());
210*c8dee2aaSAndroid Build Coastguard Worker     if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
211*c8dee2aaSAndroid Build Coastguard Worker         const Analysis::SpecializedParameters& specializedParams = specializations->front();
212*c8dee2aaSAndroid Build Coastguard Worker         const SkSpan<Variable* const> funcParams = func.parameters();
213*c8dee2aaSAndroid Build Coastguard Worker 
214*c8dee2aaSAndroid Build Coastguard Worker         for (size_t index = 0; index < funcParams.size(); ++index) {
215*c8dee2aaSAndroid Build Coastguard Worker             if (specializedParams.find(funcParams[index])) {
216*c8dee2aaSAndroid Build Coastguard Worker                 result.set(index);
217*c8dee2aaSAndroid Build Coastguard Worker             }
218*c8dee2aaSAndroid Build Coastguard Worker         }
219*c8dee2aaSAndroid Build Coastguard Worker     }
220*c8dee2aaSAndroid Build Coastguard Worker 
221*c8dee2aaSAndroid Build Coastguard Worker     return result;
222*c8dee2aaSAndroid Build Coastguard Worker }
223*c8dee2aaSAndroid Build Coastguard Worker 
GetParameterMappingsForFunction(const FunctionDeclaration & func,const SpecializationInfo & info,SpecializationIndex specializationIndex,const ParameterMappingCallback & callback)224*c8dee2aaSAndroid Build Coastguard Worker void GetParameterMappingsForFunction(const FunctionDeclaration& func,
225*c8dee2aaSAndroid Build Coastguard Worker                                      const SpecializationInfo& info,
226*c8dee2aaSAndroid Build Coastguard Worker                                      SpecializationIndex specializationIndex,
227*c8dee2aaSAndroid Build Coastguard Worker                                      const ParameterMappingCallback& callback) {
228*c8dee2aaSAndroid Build Coastguard Worker     if (specializationIndex != Analysis::kUnspecialized) {
229*c8dee2aaSAndroid Build Coastguard Worker         if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
230*c8dee2aaSAndroid Build Coastguard Worker             const Analysis::SpecializedParameters& specializedParams =
231*c8dee2aaSAndroid Build Coastguard Worker                     specializations->at(specializationIndex);
232*c8dee2aaSAndroid Build Coastguard Worker             const SkSpan<Variable* const> funcParams = func.parameters();
233*c8dee2aaSAndroid Build Coastguard Worker 
234*c8dee2aaSAndroid Build Coastguard Worker             for (size_t index = 0; index < funcParams.size(); ++index) {
235*c8dee2aaSAndroid Build Coastguard Worker                 const Variable* param = funcParams[index];
236*c8dee2aaSAndroid Build Coastguard Worker                 if (const Expression** expr = specializedParams.find(param)) {
237*c8dee2aaSAndroid Build Coastguard Worker                     callback(index, param, *expr);
238*c8dee2aaSAndroid Build Coastguard Worker                 }
239*c8dee2aaSAndroid Build Coastguard Worker             }
240*c8dee2aaSAndroid Build Coastguard Worker         }
241*c8dee2aaSAndroid Build Coastguard Worker     }
242*c8dee2aaSAndroid Build Coastguard Worker }
243*c8dee2aaSAndroid Build Coastguard Worker 
244*c8dee2aaSAndroid Build Coastguard Worker }  // namespace SkSL::Analysis
245