1 /*
2 * Copyright 2024 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/analysis/SkSLSpecialization.h"
9
10 #include "include/private/base/SkAssert.h"
11 #include "include/private/base/SkSpan_impl.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLDefines.h"
14 #include "src/sksl/analysis/SkSLProgramVisitor.h"
15 #include "src/sksl/ir/SkSLExpression.h"
16 #include "src/sksl/ir/SkSLFieldAccess.h"
17 #include "src/sksl/ir/SkSLFunctionCall.h"
18 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
19 #include "src/sksl/ir/SkSLFunctionDefinition.h"
20 #include "src/sksl/ir/SkSLProgram.h"
21 #include "src/sksl/ir/SkSLProgramElement.h"
22 #include "src/sksl/ir/SkSLVariable.h"
23 #include "src/sksl/ir/SkSLVariableReference.h"
24
25 #include <algorithm>
26 #include <memory>
27
28 using namespace skia_private;
29
30 namespace SkSL::Analysis {
31
parameter_mappings_are_equal(const SpecializedParameters & left,const SpecializedParameters & right)32 static bool parameter_mappings_are_equal(const SpecializedParameters& left,
33 const SpecializedParameters& right) {
34 if (left.count() != right.count()) {
35 return false;
36 }
37 for (const auto& [key, leftExpr] : left) {
38 const Expression** rightExpr = right.find(key);
39 if (!rightExpr) {
40 return false;
41 }
42 if (!Analysis::IsSameExpressionTree(*leftExpr, **rightExpr)) {
43 return false;
44 }
45 }
46 return true;
47 }
48
FindFunctionsToSpecialize(const Program & program,SpecializationInfo * info,const ParameterMatchesFn & parameterMatchesFn)49 void FindFunctionsToSpecialize(const Program& program,
50 SpecializationInfo* info,
51 const ParameterMatchesFn& parameterMatchesFn) {
52 class Searcher : public ProgramVisitor {
53 public:
54 using ProgramVisitor::visitProgramElement;
55 using INHERITED = ProgramVisitor;
56
57 Searcher(SpecializationInfo& info, const ParameterMatchesFn& parameterMatchesFn)
58 : fSpecializationMap(info.fSpecializationMap)
59 , fSpecializedCallMap(info.fSpecializedCallMap)
60 , fParameterMatchesFn(parameterMatchesFn) {}
61
62 bool visitExpression(const Expression& expr) override {
63 if (expr.is<FunctionCall>()) {
64 const FunctionCall& call = expr.as<FunctionCall>();
65 const FunctionDeclaration& decl = call.function();
66
67 if (!decl.isIntrinsic()) {
68 SpecializedParameters specialization;
69
70 const int numParams = decl.parameters().size();
71 SkASSERT(call.arguments().size() == numParams);
72
73 for (int i = 0; i < numParams; i++) {
74 const Expression& arg = *call.arguments()[i];
75
76 // Specializations can only be made on arguments that are not complex
77 // expressions but only a variable reference or field access since these
78 // references will be inlined in the generated specialized functions.
79 const Variable* argBase = nullptr;
80 if (arg.is<VariableReference>()) {
81 argBase = arg.as<VariableReference>().variable();
82 } else if (arg.is<FieldAccess>() &&
83 arg.as<FieldAccess>().base()->is<VariableReference>()) {
84 argBase =
85 arg.as<FieldAccess>().base()->as<VariableReference>().variable();
86 } else {
87 continue;
88 }
89 SkASSERT(argBase);
90
91 const Variable* param = decl.parameters()[i];
92
93 // Check that this parameter fits the criteria to create a specialization.
94 if (!fParameterMatchesFn(*param)) {
95 continue;
96 }
97
98 if (argBase->storage() == Variable::Storage::kGlobal) {
99 specialization[param] = &arg;
100 } else if (argBase->storage() == Variable::Storage::kParameter) {
101 const Expression** uniformExpr =
102 fInheritedSpecializations.find(argBase);
103 SkASSERT(uniformExpr);
104
105 specialization[param] = *uniformExpr;
106 } else {
107 // TODO(b/353532475): Report an error instead of aborting.
108 SK_ABORT("Specialization requires a uniform or parameter variable");
109 }
110 }
111
112 // Only create a specialization for this function if there are
113 // variables to specialize on.
114 if (specialization.count() > 0) {
115 Specializations& specializations = fSpecializationMap[&decl];
116 SpecializedCallKey callKey{call.stablePointer(),
117 fInheritedSpecializationIndex};
118
119 for (int i = 0; i < specializations.size(); i++) {
120 const SpecializedParameters& entry = specializations[i];
121 if (parameter_mappings_are_equal(specialization, entry)) {
122 // This specialization has already been tracked.
123 fSpecializedCallMap[callKey] = i;
124 return INHERITED::visitExpression(expr);
125 }
126 }
127
128 // Set the index to the corresponding specialization this function call
129 // requires, also tracking the inherited specialization this function
130 // call is in so the right specialized function can be called.
131 SpecializationIndex specializationIndex = specializations.size();
132 fSpecializedCallMap[callKey] = specializationIndex;
133 specializations.push_back(specialization);
134
135 // We swap so we don't lose when our last inherited specializations were
136 // once we are done traversing this specific specialization.
137 fInheritedSpecializations.swap(specialization);
138 std::swap(fInheritedSpecializationIndex, specializationIndex);
139
140 this->visitProgramElement(*decl.definition());
141
142 std::swap(fInheritedSpecializationIndex, specializationIndex);
143 fInheritedSpecializations.swap(specialization);
144 } else {
145 // The function being called isn't specialized, but we need to walk the
146 // entire call graph or we may miss a specialized call entirely. Since
147 // nothing is specialized, it is safe to skip over repeated traversals.
148 if (!fVisitedFunctions.find(&decl)) {
149 fVisitedFunctions.add(&decl);
150 this->visitProgramElement(*decl.definition());
151 }
152 }
153 }
154 }
155 return INHERITED::visitExpression(expr);
156 }
157
158 private:
159 SpecializationMap& fSpecializationMap;
160 SpecializedCallMap& fSpecializedCallMap;
161 const ParameterMatchesFn& fParameterMatchesFn;
162 THashSet<const FunctionDeclaration*> fVisitedFunctions;
163
164 SpecializedParameters fInheritedSpecializations;
165 SpecializationIndex fInheritedSpecializationIndex = kUnspecialized;
166 };
167
168 for (const ProgramElement* elem : program.elements()) {
169 if (elem->is<FunctionDefinition>()) {
170 const FunctionDeclaration& decl = elem->as<FunctionDefinition>().declaration();
171
172 if (decl.isMain()) {
173 // Visit through the program call stack and aggregates any necessary
174 // function specializations.
175 Searcher(*info, parameterMatchesFn).visitProgramElement(*elem);
176 continue;
177 }
178
179 // Look for any function parameter which needs specialization.
180 for (const Variable* param : decl.parameters()) {
181 if (parameterMatchesFn(*param)) {
182 // We found a function that requires specialization. Ensure that this function
183 // ends up in the specialization map, whether or not it is reachable from
184 // main().
185 //
186 // Doing this here allows unreachable specialized functions to be discarded,
187 // because it will be in the specialization map with an array of zero necessary
188 // specializations to emit. If we didn't add this function to the specialization
189 // map at all, the code generator would try to emit it without applying
190 // specializations, and generally this would lead to invalid code.
191 info->fSpecializationMap[&decl];
192 break;
193 }
194 }
195 }
196 }
197 }
198
FindSpecializationIndexForCall(const FunctionCall & call,const SpecializationInfo & info,SpecializationIndex parentSpecializationIndex)199 SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
200 const SpecializationInfo& info,
201 SpecializationIndex parentSpecializationIndex) {
202 SpecializedCallKey callKey{call.stablePointer(), parentSpecializationIndex};
203 SpecializationIndex* foundIndex = info.fSpecializedCallMap.find(callKey);
204 return foundIndex ? *foundIndex : kUnspecialized;
205 }
206
FindSpecializedParametersForFunction(const FunctionDeclaration & func,const SpecializationInfo & info)207 SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
208 const SpecializationInfo& info) {
209 SkBitSet result(func.parameters().size());
210 if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
211 const Analysis::SpecializedParameters& specializedParams = specializations->front();
212 const SkSpan<Variable* const> funcParams = func.parameters();
213
214 for (size_t index = 0; index < funcParams.size(); ++index) {
215 if (specializedParams.find(funcParams[index])) {
216 result.set(index);
217 }
218 }
219 }
220
221 return result;
222 }
223
GetParameterMappingsForFunction(const FunctionDeclaration & func,const SpecializationInfo & info,SpecializationIndex specializationIndex,const ParameterMappingCallback & callback)224 void GetParameterMappingsForFunction(const FunctionDeclaration& func,
225 const SpecializationInfo& info,
226 SpecializationIndex specializationIndex,
227 const ParameterMappingCallback& callback) {
228 if (specializationIndex != Analysis::kUnspecialized) {
229 if (const Specializations* specializations = info.fSpecializationMap.find(&func)) {
230 const Analysis::SpecializedParameters& specializedParams =
231 specializations->at(specializationIndex);
232 const SkSpan<Variable* const> funcParams = func.parameters();
233
234 for (size_t index = 0; index < funcParams.size(); ++index) {
235 const Variable* param = funcParams[index];
236 if (const Expression** expr = specializedParams.find(param)) {
237 callback(index, param, *expr);
238 }
239 }
240 }
241 }
242 }
243
244 } // namespace SkSL::Analysis
245