xref: /aosp_15_r20/external/skia/src/sksl/analysis/SkSLSpecialization.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2024 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/analysis/SkSLSpecialization.h"
9 
10 #include "include/private/base/SkAssert.h"
11 #include "include/private/base/SkSpan_impl.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLDefines.h"
14 #include "src/sksl/analysis/SkSLProgramVisitor.h"
15 #include "src/sksl/ir/SkSLExpression.h"
16 #include "src/sksl/ir/SkSLFieldAccess.h"
17 #include "src/sksl/ir/SkSLFunctionCall.h"
18 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
19 #include "src/sksl/ir/SkSLFunctionDefinition.h"
20 #include "src/sksl/ir/SkSLProgram.h"
21 #include "src/sksl/ir/SkSLProgramElement.h"
22 #include "src/sksl/ir/SkSLVariable.h"
23 #include "src/sksl/ir/SkSLVariableReference.h"
24 
25 #include <algorithm>
26 #include <memory>
27 
28 using namespace skia_private;
29 
30 namespace SkSL::Analysis {
31 
parameter_mappings_are_equal(const SpecializedParameters & left,const SpecializedParameters & right)32 static bool parameter_mappings_are_equal(const SpecializedParameters& left,
33                                          const SpecializedParameters& right) {
34     if (left.count() != right.count()) {
35         return false;
36     }
37     for (const auto& [key, leftExpr] : left) {
38         const Expression** rightExpr = right.find(key);
39         if (!rightExpr) {
40             return false;
41         }
42         if (!Analysis::IsSameExpressionTree(*leftExpr, **rightExpr)) {
43             return false;
44         }
45     }
46     return true;
47 }
48 
FindFunctionsToSpecialize(const Program & program,SpecializationInfo * info,const ParameterMatchesFn & parameterMatchesFn)49 void FindFunctionsToSpecialize(const Program& program,
50                                SpecializationInfo* info,
51                                const ParameterMatchesFn& parameterMatchesFn) {
52     class Searcher : public ProgramVisitor {
53     public:
54         using ProgramVisitor::visitProgramElement;
55         using INHERITED = ProgramVisitor;
56 
57         Searcher(SpecializationInfo& info, const ParameterMatchesFn& parameterMatchesFn)
58                 : fSpecializationMap(info.fSpecializationMap)
59                 , fSpecializedCallMap(info.fSpecializedCallMap)
60                 , fParameterMatchesFn(parameterMatchesFn) {}
61 
62         bool visitExpression(const Expression& expr) override {
63             if (expr.is<FunctionCall>()) {
64                 const FunctionCall& call = expr.as<FunctionCall>();
65                 const FunctionDeclaration& decl = call.function();
66 
67                 if (!decl.isIntrinsic()) {
68                     SpecializedParameters specialization;
69 
70                     const int numParams = decl.parameters().size();
71                     SkASSERT(call.arguments().size() == numParams);
72 
73                     for (int i = 0; i < numParams; i++) {
74                         const Expression& arg = *call.arguments()[i];
75 
76                         // Specializations can only be made on arguments that are not complex
77                         // expressions but only a variable reference or field access since these
78                         // references will be inlined in the generated specialized functions.
79                         const Variable* argBase = nullptr;
80                         if (arg.is<VariableReference>()) {
81                             argBase = arg.as<VariableReference>().variable();
82                         } else if (arg.is<FieldAccess>() &&
83                                    arg.as<FieldAccess>().base()->is<VariableReference>()) {
84                             argBase =
85                                 arg.as<FieldAccess>().base()->as<VariableReference>().variable();
86                         } else {
87                             continue;
88                         }
89                         SkASSERT(argBase);
90 
91                         const Variable* param = decl.parameters()[i];
92 
93                         // Check that this parameter fits the criteria to create a specialization.
94                         if (!fParameterMatchesFn(*param)) {
95                             continue;
96                         }
97 
98                         if (argBase->storage() == Variable::Storage::kGlobal) {
99                             specialization[param] = &arg;
100                         } else if (argBase->storage() == Variable::Storage::kParameter) {
101                             const Expression** uniformExpr =
102                                 fInheritedSpecializations.find(argBase);
103                             SkASSERT(uniformExpr);
104 
105                             specialization[param] = *uniformExpr;
106                         } else {
107                             // TODO(b/353532475): Report an error instead of aborting.
108                             SK_ABORT("Specialization requires a uniform or parameter variable");
109                         }
110                     }
111 
112                     // Only create a specialization for this function if there are
113                     // variables to specialize on.
114                     if (specialization.count() > 0) {
115                         Specializations& specializations = fSpecializationMap[&decl];
116                         SpecializedCallKey callKey{call.stablePointer(),
117                                                    fInheritedSpecializationIndex};
118 
119                         for (int i = 0; i < specializations.size(); i++) {
120                             const SpecializedParameters& entry = specializations[i];
121                             if (parameter_mappings_are_equal(specialization, entry)) {
122                                 // This specialization has already been tracked.
123                                 fSpecializedCallMap[callKey] = i;
124                                 return INHERITED::visitExpression(expr);
125                             }
126                         }
127 
128                         // Set the index to the corresponding specialization this function call
129                         // requires, also tracking the inherited specialization this function
130                         // call is in so the right specialized function can be called.
131                         SpecializationIndex specializationIndex = specializations.size();
132                         fSpecializedCallMap[callKey] = specializationIndex;
133                         specializations.push_back(specialization);
134 
135                         // We swap so we don't lose when our last inherited specializations were
136                         // once we are done traversing this specific specialization.
137                         fInheritedSpecializations.swap(specialization);
138                         std::swap(fInheritedSpecializationIndex, specializationIndex);
139 
140                         this->visitProgramElement(*decl.definition());
141 
142                         std::swap(fInheritedSpecializationIndex, specializationIndex);
143                         fInheritedSpecializations.swap(specialization);
144                     } else {
145                         // The function being called isn't specialized, but we need to walk the
146                         // entire call graph or we may miss a specialized call entirely. Since
147                         // nothing is specialized, it is safe to skip over repeated traversals.
148                         if (!fVisitedFunctions.find(&decl)) {
149                             fVisitedFunctions.add(&decl);
150                             this->visitProgramElement(*decl.definition());
151                         }
152                     }
153                 }
154             }
155             return INHERITED::visitExpression(expr);
156         }
157 
158     private:
159         SpecializationMap& fSpecializationMap;
160         SpecializedCallMap& fSpecializedCallMap;
161         const ParameterMatchesFn& fParameterMatchesFn;
162         THashSet<const FunctionDeclaration*> fVisitedFunctions;
163 
164         SpecializedParameters fInheritedSpecializations;
165         SpecializationIndex fInheritedSpecializationIndex = kUnspecialized;
166     };
167 
168     for (const ProgramElement* elem : program.elements()) {
169         if (elem->is<FunctionDefinition>()) {
170             const FunctionDeclaration& decl = elem->as<FunctionDefinition>().declaration();
171 
172             if (decl.isMain()) {
173                 // Visit through the program call stack and aggregates any necessary
174                 // function specializations.
175                 Searcher(*info, parameterMatchesFn).visitProgramElement(*elem);
176                 continue;
177             }
178 
179             // Look for any function parameter which needs specialization.
180             for (const Variable* param : decl.parameters()) {
181                 if (parameterMatchesFn(*param)) {
182                     // We found a function that requires specialization. Ensure that this function
183                     // ends up in the specialization map, whether or not it is reachable from
184                     // main().
185                     //
186                     // Doing this here allows unreachable specialized functions to be discarded,
187                     // because it will be in the specialization map with an array of zero necessary
188                     // specializations to emit. If we didn't add this function to the specialization
189                     // map at all, the code generator would try to emit it without applying
190                     // specializations, and generally this would lead to invalid code.
191                     info->fSpecializationMap[&decl];
192                     break;
193                 }
194             }
195         }
196     }
197 }
198 
FindSpecializationIndexForCall(const FunctionCall & call,const SpecializationInfo & info,SpecializationIndex parentSpecializationIndex)199 SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
200                                                    const SpecializationInfo& info,
201                                                    SpecializationIndex parentSpecializationIndex) {
202     SpecializedCallKey callKey{call.stablePointer(), parentSpecializationIndex};
203     SpecializationIndex* foundIndex = info.fSpecializedCallMap.find(callKey);
204     return foundIndex ? *foundIndex : kUnspecialized;
205 }
206 
FindSpecializedParametersForFunction(const FunctionDeclaration & func,const SpecializationInfo & info)207 SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
208                                               const SpecializationInfo& info) {
209     SkBitSet result(func.parameters().size());
210     if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
211         const Analysis::SpecializedParameters& specializedParams = specializations->front();
212         const SkSpan<Variable* const> funcParams = func.parameters();
213 
214         for (size_t index = 0; index < funcParams.size(); ++index) {
215             if (specializedParams.find(funcParams[index])) {
216                 result.set(index);
217             }
218         }
219     }
220 
221     return result;
222 }
223 
GetParameterMappingsForFunction(const FunctionDeclaration & func,const SpecializationInfo & info,SpecializationIndex specializationIndex,const ParameterMappingCallback & callback)224 void GetParameterMappingsForFunction(const FunctionDeclaration& func,
225                                      const SpecializationInfo& info,
226                                      SpecializationIndex specializationIndex,
227                                      const ParameterMappingCallback& callback) {
228     if (specializationIndex != Analysis::kUnspecialized) {
229         if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
230             const Analysis::SpecializedParameters& specializedParams =
231                     specializations->at(specializationIndex);
232             const SkSpan<Variable* const> funcParams = func.parameters();
233 
234             for (size_t index = 0; index < funcParams.size(); ++index) {
235                 const Variable* param = funcParams[index];
236                 if (const Expression** expr = specializedParams.find(param)) {
237                     callback(index, param, *expr);
238                 }
239             }
240         }
241     }
242 }
243 
244 }  // namespace SkSL::Analysis
245