1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
#include "src/sksl/ir/SkSLFunctionDefinition.h"

#include "include/core/SkSpan.h"
#include "include/core/SkTypes.h"
#include "src/base/SkSafeMath.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLCompiler.h"
#include "src/sksl/SkSLContext.h"
#include "src/sksl/SkSLDefines.h"
#include "src/sksl/SkSLErrorReporter.h"
#include "src/sksl/SkSLOperator.h"
#include "src/sksl/SkSLProgramSettings.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBlock.h"
#include "src/sksl/ir/SkSLExpression.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
#include "src/sksl/ir/SkSLFieldSymbol.h"
#include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLIRHelpers.h"
#include "src/sksl/ir/SkSLNop.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
#include "src/sksl/ir/SkSLSwizzle.h"
#include "src/sksl/ir/SkSLSymbol.h"
#include "src/sksl/ir/SkSLSymbolTable.h" // IWYU pragma: keep
#include "src/sksl/ir/SkSLType.h"
#include "src/sksl/ir/SkSLVarDeclarations.h"
#include "src/sksl/ir/SkSLVariable.h"
#include "src/sksl/ir/SkSLVariableReference.h"
#include "src/sksl/transform/SkSLProgramWriter.h"

#include <algorithm>
#include <cstddef>
#include <forward_list>
40
41 namespace SkSL {
42
append_rtadjust_fixup_to_vertex_main(const Context & context,const FunctionDeclaration & decl,Block & body)43 static void append_rtadjust_fixup_to_vertex_main(const Context& context,
44 const FunctionDeclaration& decl,
45 Block& body) {
46 // If this program uses RTAdjust...
47 if (const SkSL::Symbol* rtAdjust = context.fSymbolTable->find(Compiler::RTADJUST_NAME)) {
48 // ...append a line to the end of the function body which fixes up sk_Position.
49 struct AppendRTAdjustFixupHelper : public IRHelpers {
50 AppendRTAdjustFixupHelper(const Context& ctx, const SkSL::Symbol* rtAdjust)
51 : IRHelpers(ctx)
52 , fRTAdjust(rtAdjust) {
53 fSkPositionField = &fContext.fSymbolTable->find(Compiler::POSITION_NAME)
54 ->as<FieldSymbol>();
55 }
56
57 std::unique_ptr<Expression> Pos() const {
58 return Field(&fSkPositionField->owner(), fSkPositionField->fieldIndex());
59 }
60
61 std::unique_ptr<Expression> Adjust() const {
62 return fRTAdjust->instantiate(fContext, Position());
63 }
64
65 std::unique_ptr<Statement> makeFixupStmt() const {
66 // sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
67 // 0,
68 // sk_Position.w);
69 return Assign(
70 Pos(),
71 CtorXYZW(Add(Mul(Swizzle(Pos(), {SwizzleComponent::X, SwizzleComponent::Y}),
72 Swizzle(Adjust(), {SwizzleComponent::X, SwizzleComponent::Z})),
73 Mul(Swizzle(Pos(), {SwizzleComponent::W, SwizzleComponent::W}),
74 Swizzle(Adjust(), {SwizzleComponent::Y, SwizzleComponent::W}))),
75 Float(0.0),
76 Swizzle(Pos(), {SwizzleComponent::W})));
77 }
78
79 const FieldSymbol* fSkPositionField;
80 const SkSL::Symbol* fRTAdjust;
81 };
82
83 AppendRTAdjustFixupHelper helper(context, rtAdjust);
84 body.children().push_back(helper.makeFixupStmt());
85 }
86 }
87
Convert(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body)88 std::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
89 Position pos,
90 const FunctionDeclaration& function,
91 std::unique_ptr<Statement> body) {
92 class Finalizer : public ProgramWriter {
93 public:
94 Finalizer(const Context& context, const FunctionDeclaration& function, Position pos)
95 : fContext(context)
96 , fFunction(function) {
97 // Function parameters count as local variables.
98 for (const Variable* var : function.parameters()) {
99 this->addLocalVariable(var, pos);
100 }
101 }
102
103 ~Finalizer() override {
104 SkASSERT(fBreakableLevel == 0);
105 SkASSERT(fContinuableLevel == std::forward_list<int>{0});
106 }
107
108 void addLocalVariable(const Variable* var, Position pos) {
109 if (var->type().isOrContainsUnsizedArray()) {
110 if (var->storage() != Variable::Storage::kParameter) {
111 fContext.fErrors->error(pos, "unsized arrays are not permitted here");
112 }
113 // Number of slots does not apply to unsized arrays since they are
114 // dynamically sized.
115 return;
116 }
117 // We count the number of slots used, but don't consider the precision of the base type.
118 // In practice, this reflects what GPUs actually do pretty well. (i.e., RelaxedPrecision
119 // math doesn't mean your variable takes less space.) We also don't attempt to reclaim
120 // slots at the end of a Block.
121 size_t prevSlotsUsed = fSlotsUsed;
122 fSlotsUsed = SkSafeMath::Add(fSlotsUsed, var->type().slotCount());
123 // To avoid overzealous error reporting, only trigger the error at the first
124 // place where the stack limit is exceeded.
125 if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
126 fContext.fErrors->error(pos, "variable '" + std::string(var->name()) +
127 "' exceeds the stack size limit");
128 }
129 }
130
131 void fuseVariableDeclarationsWithInitialization(std::unique_ptr<Statement>& stmt) {
132 switch (stmt->kind()) {
133 case Statement::Kind::kNop:
134 case Statement::Kind::kBlock:
135 // Blocks and no-ops are inert; it is safe to fuse a variable declaration with
136 // its initialization across a nop or an open-brace, so we don't null out
137 // `fUninitializedVarDecl` here.
138 break;
139
140 case Statement::Kind::kVarDeclaration:
141 // Look for variable declarations without an initializer.
142 if (VarDeclaration& decl = stmt->as<VarDeclaration>(); !decl.value()) {
143 fUninitializedVarDecl = &decl;
144 break;
145 }
146 [[fallthrough]];
147
148 default:
149 // We found an intervening statement; it's not safe to fuse a declaration
150 // with an initializer if we encounter any other code.
151 fUninitializedVarDecl = nullptr;
152 break;
153
154 case Statement::Kind::kExpression: {
155 // We found an expression-statement. If there was a variable declaration
156 // immediately above it, it might be possible to fuse them.
157 if (fUninitializedVarDecl) {
158 VarDeclaration* vardecl = fUninitializedVarDecl;
159 fUninitializedVarDecl = nullptr;
160
161 std::unique_ptr<Expression>& nextExpr = stmt->as<ExpressionStatement>()
162 .expression();
163 // This statement must be a binary-expression...
164 if (!nextExpr->is<BinaryExpression>()) {
165 break;
166 }
167 // ... performing simple `var = expr` assignment...
168 BinaryExpression& binaryExpr = nextExpr->as<BinaryExpression>();
169 if (binaryExpr.getOperator().kind() != OperatorKind::EQ) {
170 break;
171 }
172 // ... directly into the variable (not a field/swizzle)...
173 Expression& leftExpr = *binaryExpr.left();
174 if (!leftExpr.is<VariableReference>()) {
175 break;
176 }
177 // ... and it must be the same variable as our vardecl.
178 VariableReference& varRef = leftExpr.as<VariableReference>();
179 if (varRef.variable() != vardecl->var()) {
180 break;
181 }
182 // The init-expression must not reference the variable.
183 // `int x; x = x = 0;` is legal SkSL, but `int x = x = 0;` is not.
184 if (Analysis::ContainsVariable(*binaryExpr.right(), *varRef.variable())) {
185 break;
186 }
187 // We found a match! Move the init-expression directly onto the vardecl, and
188 // turn the assignment into a no-op.
189 vardecl->value() = std::move(binaryExpr.right());
190
191 // Turn the expression-statement into a no-op.
192 stmt = Nop::Make();
193 }
194 break;
195 }
196 }
197 }
198
199 bool functionReturnsValue() const {
200 return !fFunction.returnType().isVoid();
201 }
202
203 bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
204 // We don't need to scan expressions.
205 return false;
206 }
207
208 bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
209 // When the optimizer is on, we look for variable declarations that are immediately
210 // followed by an initialization expression, and fuse them into one statement.
211 // (e.g.: `int i; i = 1;` can become `int i = 1;`)
212 if (fContext.fConfig->fSettings.fOptimize) {
213 this->fuseVariableDeclarationsWithInitialization(stmt);
214 }
215
216 // Perform error checking.
217 switch (stmt->kind()) {
218 case Statement::Kind::kVarDeclaration:
219 this->addLocalVariable(stmt->as<VarDeclaration>().var(), stmt->fPosition);
220 break;
221
222 case Statement::Kind::kReturn: {
223 // Early returns from a vertex main() function will bypass sk_Position
224 // normalization, so SkASSERT that we aren't doing that. If this becomes an
225 // issue, we can add normalization before each return statement.
226 if (ProgramConfig::IsVertex(fContext.fConfig->fKind) && fFunction.isMain()) {
227 fContext.fErrors->error(
228 stmt->fPosition,
229 "early returns from vertex programs are not supported");
230 }
231
232 // Verify that the return statement matches the function's return type.
233 ReturnStatement& returnStmt = stmt->as<ReturnStatement>();
234 if (returnStmt.expression()) {
235 if (this->functionReturnsValue()) {
236 // Coerce return expression to the function's return type.
237 returnStmt.setExpression(fFunction.returnType().coerceExpression(
238 std::move(returnStmt.expression()), fContext));
239 } else {
240 // Returning something from a function with a void return type.
241 fContext.fErrors->error(returnStmt.expression()->fPosition,
242 "may not return a value from a void function");
243 returnStmt.setExpression(nullptr);
244 }
245 } else {
246 if (this->functionReturnsValue()) {
247 // Returning nothing from a function with a non-void return type.
248 fContext.fErrors->error(returnStmt.fPosition,
249 "expected function to return '" +
250 fFunction.returnType().displayName() + "'");
251 }
252 }
253 break;
254 }
255 case Statement::Kind::kDo:
256 case Statement::Kind::kFor: {
257 ++fBreakableLevel;
258 ++fContinuableLevel.front();
259 bool result = INHERITED::visitStatementPtr(stmt);
260 --fContinuableLevel.front();
261 --fBreakableLevel;
262 return result;
263 }
264 case Statement::Kind::kSwitch: {
265 ++fBreakableLevel;
266 fContinuableLevel.push_front(0);
267 bool result = INHERITED::visitStatementPtr(stmt);
268 fContinuableLevel.pop_front();
269 --fBreakableLevel;
270 return result;
271 }
272 case Statement::Kind::kBreak:
273 if (fBreakableLevel == 0) {
274 fContext.fErrors->error(stmt->fPosition,
275 "break statement must be inside a loop or switch");
276 }
277 break;
278
279 case Statement::Kind::kContinue:
280 if (fContinuableLevel.front() == 0) {
281 if (std::any_of(fContinuableLevel.begin(),
282 fContinuableLevel.end(),
283 [](int level) { return level > 0; })) {
284 fContext.fErrors->error(stmt->fPosition,
285 "continue statement cannot be used in a switch");
286 } else {
287 fContext.fErrors->error(stmt->fPosition,
288 "continue statement must be inside a loop");
289 }
290 }
291 break;
292
293 default:
294 break;
295 }
296 return INHERITED::visitStatementPtr(stmt);
297 }
298
299 private:
300 const Context& fContext;
301 const FunctionDeclaration& fFunction;
302 // how deeply nested we are in breakable constructs (for, do, switch).
303 int fBreakableLevel = 0;
304 // number of slots consumed by all variables declared in the function
305 size_t fSlotsUsed = 0;
306 // how deeply nested we are in continuable constructs (for, do).
307 // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
308 std::forward_list<int> fContinuableLevel{0};
309 // We track uninitialized variable declarations, and if they are immediately assigned-to,
310 // we can move the assignment directly into the decl.
311 VarDeclaration* fUninitializedVarDecl = nullptr;
312
313 using INHERITED = ProgramWriter;
314 };
315
316 // We don't allow modules to define actual functions with intrinsic names. (Those should be
317 // reserved for actual intrinsics.)
318 if (function.isIntrinsic()) {
319 context.fErrors->error(pos, "intrinsic function '" + std::string(function.name()) +
320 "' should not have a definition");
321 return nullptr;
322 }
323
324 // A function body must always be a braced block. (The parser should enforce this already, but
325 // we rely on it, so it's best to be certain.)
326 if (!body || !body->is<Block>() || !body->as<Block>().isScope()) {
327 context.fErrors->error(pos, "function body '" + function.description() +
328 "' must be a braced block");
329 return nullptr;
330 }
331
332 // A function can't have more than one definition.
333 if (function.definition()) {
334 context.fErrors->error(pos, "function '" + function.description() +
335 "' was already defined");
336 return nullptr;
337 }
338
339 // Run the function finalizer. This checks for illegal constructs and missing return statements,
340 // and also performs some simple code cleanup.
341 Finalizer(context, function, pos).visitStatementPtr(body);
342 if (function.isMain() && ProgramConfig::IsVertex(context.fConfig->fKind)) {
343 append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
344 }
345
346 if (Analysis::CanExitWithoutReturningValue(function, *body)) {
347 context.fErrors->error(body->fPosition, "function '" + std::string(function.name()) +
348 "' can exit without returning a value");
349 }
350
351 return FunctionDefinition::Make(context, pos, function, std::move(body));
352 }
353
Make(const Context & context,Position pos,const FunctionDeclaration & function,std::unique_ptr<Statement> body)354 std::unique_ptr<FunctionDefinition> FunctionDefinition::Make(const Context& context,
355 Position pos,
356 const FunctionDeclaration& function,
357 std::unique_ptr<Statement> body) {
358 SkASSERT(!function.isIntrinsic());
359 SkASSERT(body && body->as<Block>().isScope());
360 SkASSERT(!function.definition());
361
362 return std::make_unique<FunctionDefinition>(pos, &function, std::move(body));
363 }
364
365 } // namespace SkSL
366