xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLFunctionDefinition.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLFunctionDefinition.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "src/base/SkSafeMath.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLCompiler.h"
15 #include "src/sksl/SkSLContext.h"
16 #include "src/sksl/SkSLDefines.h"
17 #include "src/sksl/SkSLErrorReporter.h"
18 #include "src/sksl/SkSLOperator.h"
19 #include "src/sksl/SkSLProgramSettings.h"
20 #include "src/sksl/ir/SkSLBinaryExpression.h"
21 #include "src/sksl/ir/SkSLBlock.h"
22 #include "src/sksl/ir/SkSLExpression.h"
23 #include "src/sksl/ir/SkSLExpressionStatement.h"
24 #include "src/sksl/ir/SkSLFieldSymbol.h"
25 #include "src/sksl/ir/SkSLIRHelpers.h"
26 #include "src/sksl/ir/SkSLNop.h"
27 #include "src/sksl/ir/SkSLReturnStatement.h"
28 #include "src/sksl/ir/SkSLSwizzle.h"
29 #include "src/sksl/ir/SkSLSymbol.h"
30 #include "src/sksl/ir/SkSLSymbolTable.h"  // IWYU pragma: keep
31 #include "src/sksl/ir/SkSLType.h"
32 #include "src/sksl/ir/SkSLVarDeclarations.h"
33 #include "src/sksl/ir/SkSLVariable.h"
34 #include "src/sksl/ir/SkSLVariableReference.h"
35 #include "src/sksl/transform/SkSLProgramWriter.h"
36 
37 #include <algorithm>
38 #include <cstddef>
39 #include <forward_list>
40 
41 namespace SkSL {
42 
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)43 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
44                                                  const FunctionDeclaration& decl,
45                                                  Block& body) {
46     // If this program uses RTAdjust...
47     if (const SkSL::Symbol* rtAdjust = context.fSymbolTable->find(Compiler::RTADJUST_NAME)) {
48         // ...append a line to the end of the function body which fixes up sk_Position.
49         struct AppendRTAdjustFixupHelper : public IRHelpers {
50             AppendRTAdjustFixupHelper(const Context& ctx, const SkSL::Symbol* rtAdjust)
51                     : IRHelpers(ctx)
52                     , fRTAdjust(rtAdjust) {
53                 fSkPositionField = &fContext.fSymbolTable->find(Compiler::POSITION_NAME)
54                                                          ->as<FieldSymbol>();
55             }
56 
57             std::unique_ptr<Expression> Pos() const {
58                 return Field(&fSkPositionField->owner(), fSkPositionField->fieldIndex());
59             }
60 
61             std::unique_ptr<Expression> Adjust() const {
62                 return fRTAdjust->instantiate(fContext, Position());
63             }
64 
65             std::unique_ptr<Statement> makeFixupStmt() const {
66                 // sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
67                 //                      0,
68                 //                      sk_Position.w);
69                 return Assign(
70                    Pos(),
71                    CtorXYZW(Add(Mul(Swizzle(Pos(),    {SwizzleComponent::X, SwizzleComponent::Y}),
72                                     Swizzle(Adjust(), {SwizzleComponent::X, SwizzleComponent::Z})),
73                                 Mul(Swizzle(Pos(),    {SwizzleComponent::W, SwizzleComponent::W}),
74                                     Swizzle(Adjust(), {SwizzleComponent::Y, SwizzleComponent::W}))),
75                             Float(0.0),
76                             Swizzle(Pos(), {SwizzleComponent::W})));
77             }
78 
79             const FieldSymbol* fSkPositionField;
80             const SkSL::Symbol* fRTAdjust;
81         };
82 
83         AppendRTAdjustFixupHelper helper(context, rtAdjust);
84         body.children().push_back(helper.makeFixupStmt());
85     }
86 }
87 
Convert(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body)88 std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
89                                                                 Position pos,
90                                                                 const FunctionDeclaration& function,
91                                                                 std::unique_ptr<Statement> body) {
92     class Finalizer : public ProgramWriter {
93     public:
94         Finalizer(const Context& context, const FunctionDeclaration& function, Position pos)
95             : fContext(context)
96             , fFunction(function) {
97             // Function parameters count as local variables.
98             for (const Variable* var : function.parameters()) {
99                 this->addLocalVariable(var, pos);
100             }
101         }
102 
103         ~Finalizer() override {
104             SkASSERT(fBreakableLevel == 0);
105             SkASSERT(fContinuableLevel == std::forward_list<int>{0});
106         }
107 
108         void addLocalVariable(const Variable* var, Position pos) {
109             if (var->type().isOrContainsUnsizedArray()) {
110                 if (var->storage() != Variable::Storage::kParameter) {
111                     fContext.fErrors->error(pos, "unsized arrays are not permitted here");
112                 }
113                 // Number of slots does not apply to unsized arrays since they are
114                 // dynamically sized.
115                 return;
116             }
117             // We count the number of slots used, but don't consider the precision of the base type.
118             // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
119             // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
120             // slots at the end of a Block.
121             size_t prevSlotsUsed = fSlotsUsed;
122             fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
123             // To avoid overzealous error reporting, only trigger the error at the first
124             // place where the stack limit is exceeded.
125             if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
126                 fContext.fErrors->error(pos, "variable '" + std::string(var->name()) +
127                                              "' exceeds the stack size limit");
128             }
129         }
130 
131         void fuseVariableDeclarationsWithInitialization(std::unique_ptr<Statement>& stmt) {
132             switch (stmt->kind()) {
133                 case Statement::Kind::kNop:
134                 case Statement::Kind::kBlock:
135                     // Blocks and no-ops are inert; it is safe to fuse a variable declaration with
136                     // its initialization across a nop or an open-brace, so we don't null out
137                     // `fUninitializedVarDecl` here.
138                     break;
139 
140                 case Statement::Kind::kVarDeclaration:
141                     // Look for variable declarations without an initializer.
142                     if (VarDeclaration& decl = stmt->as<VarDeclaration>(); !decl.value()) {
143                         fUninitializedVarDecl = &decl;
144                         break;
145                     }
146                     [[fallthrough]];
147 
148                 default:
149                     // We found an intervening statement; it's not safe to fuse a declaration
150                     // with an initializer if we encounter any other code.
151                     fUninitializedVarDecl = nullptr;
152                     break;
153 
154                 case Statement::Kind::kExpression: {
155                     // We found an expression-statement. If there was a variable declaration
156                     // immediately above it, it might be possible to fuse them.
157                     if (fUninitializedVarDecl) {
158                         VarDeclaration* vardecl = fUninitializedVarDecl;
159                         fUninitializedVarDecl = nullptr;
160 
161                         std::unique_ptr<Expression>& nextExpr = stmt->as<ExpressionStatement>()
162                                                                      .expression();
163                         // This statement must be a binary-expression...
164                         if (!nextExpr->is<BinaryExpression>()) {
165                             break;
166                         }
167                         // ... performing simple `var = expr` assignment...
168                         BinaryExpression& binaryExpr = nextExpr->as<BinaryExpression>();
169                         if (binaryExpr.getOperator().kind() != OperatorKind::EQ) {
170                             break;
171                         }
172                         // ... directly into the variable (not a field/swizzle)...
173                         Expression& leftExpr = *binaryExpr.left();
174                         if (!leftExpr.is<VariableReference>()) {
175                             break;
176                         }
177                         // ... and it must be the same variable as our vardecl.
178                         VariableReference& varRef = leftExpr.as<VariableReference>();
179                         if (varRef.variable() != vardecl->var()) {
180                             break;
181                         }
182                         // The init-expression must not reference the variable.
183                         // `int x; x = x = 0;` is legal SkSL, but `int x = x = 0;` is not.
184                         if (Analysis::ContainsVariable(*binaryExpr.right(), *varRef.variable())) {
185                             break;
186                         }
187                         // We found a match! Move the init-expression directly onto the vardecl, and
188                         // turn the assignment into a no-op.
189                         vardecl->value() = std::move(binaryExpr.right());
190 
191                         // Turn the expression-statement into a no-op.
192                         stmt = Nop::Make();
193                     }
194                     break;
195                 }
196             }
197         }
198 
199         bool functionReturnsValue() const {
200             return !fFunction.returnType().isVoid();
201         }
202 
203         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
204             // We don't need to scan expressions.
205             return false;
206         }
207 
208         bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
209             // When the optimizer is on, we look for variable declarations that are immediately
210             // followed by an initialization expression, and fuse them into one statement.
211             // (e.g.: `int i; i = 1;` can become `int i = 1;`)
212             if (fContext.fConfig->fSettings.fOptimize) {
213                 this->fuseVariableDeclarationsWithInitialization(stmt);
214             }
215 
216             // Perform error checking.
217             switch (stmt->kind()) {
218                 case Statement::Kind::kVarDeclaration:
219                     this->addLocalVariable(stmt->as<VarDeclaration>().var(), stmt->fPosition);
220                     break;
221 
222                 case Statement::Kind::kReturn: {
223                     // Early returns from a vertex main() function will bypass sk_Position
224                     // normalization, so SkASSERT that we aren't doing that. If this becomes an
225                     // issue, we can add normalization before each return statement.
226                     if (ProgramConfig::IsVertex(fContext.fConfig->fKind) && fFunction.isMain()) {
227                         fContext.fErrors->error(
228                                 stmt->fPosition,
229                                 "early returns from vertex programs are not supported");
230                     }
231 
232                     // Verify that the return statement matches the function's return type.
233                     ReturnStatement& returnStmt = stmt->as<ReturnStatement>();
234                     if (returnStmt.expression()) {
235                         if (this->functionReturnsValue()) {
236                             // Coerce return expression to the function's return type.
237                             returnStmt.setExpression(fFunction.returnType().coerceExpression(
238                                     std::move(returnStmt.expression()), fContext));
239                         } else {
240                             // Returning something from a function with a void return type.
241                             fContext.fErrors->error(returnStmt.expression()->fPosition,
242                                                     "may not return a value from a void function");
243                             returnStmt.setExpression(nullptr);
244                         }
245                     } else {
246                         if (this->functionReturnsValue()) {
247                             // Returning nothing from a function with a non-void return type.
248                             fContext.fErrors->error(returnStmt.fPosition,
249                                                     "expected function to return '" +
250                                                     fFunction.returnType().displayName() + "'");
251                         }
252                     }
253                     break;
254                 }
255                 case Statement::Kind::kDo:
256                 case Statement::Kind::kFor: {
257                     ++fBreakableLevel;
258                     ++fContinuableLevel.front();
259                     bool result = INHERITED::visitStatementPtr(stmt);
260                     --fContinuableLevel.front();
261                     --fBreakableLevel;
262                     return result;
263                 }
264                 case Statement::Kind::kSwitch: {
265                     ++fBreakableLevel;
266                     fContinuableLevel.push_front(0);
267                     bool result = INHERITED::visitStatementPtr(stmt);
268                     fContinuableLevel.pop_front();
269                     --fBreakableLevel;
270                     return result;
271                 }
272                 case Statement::Kind::kBreak:
273                     if (fBreakableLevel == 0) {
274                         fContext.fErrors->error(stmt->fPosition,
275                                                 "break statement must be inside a loop or switch");
276                     }
277                     break;
278 
279                 case Statement::Kind::kContinue:
280                     if (fContinuableLevel.front() == 0) {
281                         if (std::any_of(fContinuableLevel.begin(),
282                                         fContinuableLevel.end(),
283                                         [](int level) { return level > 0; })) {
284                             fContext.fErrors->error(stmt->fPosition,
285                                                    "continue statement cannot be used in a switch");
286                         } else {
287                             fContext.fErrors->error(stmt->fPosition,
288                                                     "continue statement must be inside a loop");
289                         }
290                     }
291                     break;
292 
293                 default:
294                     break;
295             }
296             return INHERITED::visitStatementPtr(stmt);
297         }
298 
299     private:
300         const Context& fContext;
301         const FunctionDeclaration& fFunction;
302         // how deeply nested we are in breakable constructs (for, do, switch).
303         int fBreakableLevel = 0;
304         // number of slots consumed by all variables declared in the function
305         size_t fSlotsUsed = 0;
306         // how deeply nested we are in continuable constructs (for, do).
307         // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
308         std::forward_list<int> fContinuableLevel{0};
309         // We track uninitialized variable declarations, and if they are immediately assigned-to,
310         // we can move the assignment directly into the decl.
311         VarDeclaration* fUninitializedVarDecl = nullptr;
312 
313         using INHERITED = ProgramWriter;
314     };
315 
316     // We don't allow modules to define actual functions with intrinsic names. (Those should be
317     // reserved for actual intrinsics.)
318     if (function.isIntrinsic()) {
319         context.fErrors->error(pos, "intrinsic function '" + std::string(function.name()) +
320                                     "' should not have a definition");
321         return nullptr;
322     }
323 
324     // A function body must always be a braced block. (The parser should enforce this already, but
325     // we rely on it, so it's best to be certain.)
326     if (!body || !body->is<Block>() || !body->as<Block>().isScope()) {
327         context.fErrors->error(pos, "function body '" + function.description() +
328                                     "' must be a braced block");
329         return nullptr;
330     }
331 
332     // A function can't have more than one definition.
333     if (function.definition()) {
334         context.fErrors->error(pos, "function '" + function.description() +
335                                     "' was already defined");
336         return nullptr;
337     }
338 
339     // Run the function finalizer. This checks for illegal constructs and missing return statements,
340     // and also performs some simple code cleanup.
341     Finalizer(context, function, pos).visitStatementPtr(body);
342     if (function.isMain() && ProgramConfig::IsVertex(context.fConfig->fKind)) {
343         append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
344     }
345 
346     if (Analysis::CanExitWithoutReturningValue(function, *body)) {
347         context.fErrors->error(body->fPosition, "function '" + std::string(function.name()) +
348                                                 "' can exit without returning a value");
349     }
350 
351     return FunctionDefinition::Make(context, pos, function, std::move(body));
352 }
353 
Make(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body)354 std::unique_ptr<FunctionDefinition> FunctionDefinition::Make(const Context& context,
355                                                              Position pos,
356                                                              const FunctionDeclaration& function,
357                                                              std::unique_ptr<Statement> body) {
358     SkASSERT(!function.isIntrinsic());
359     SkASSERT(body && body->as<Block>().isScope());
360     SkASSERT(!function.definition());
361 
362     return std::make_unique<FunctionDefinition>(pos, &function, std::move(body));
363 }
364 
365 }  // namespace SkSL
366