xref: /aosp_15_r20/external/skia/src/sksl/transform/SkSLEliminateUnnecessaryBraces.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2024 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLDefines.h"
9 #include "src/sksl/SkSLModule.h"
10 #include "src/sksl/SkSLPosition.h"
11 #include "src/sksl/ir/SkSLBlock.h"
12 #include "src/sksl/ir/SkSLDoStatement.h"
13 #include "src/sksl/ir/SkSLExpression.h"
14 #include "src/sksl/ir/SkSLForStatement.h"
15 #include "src/sksl/ir/SkSLFunctionDefinition.h"
16 #include "src/sksl/ir/SkSLIRNode.h"
17 #include "src/sksl/ir/SkSLIfStatement.h"
18 #include "src/sksl/ir/SkSLNop.h"
19 #include "src/sksl/ir/SkSLProgramElement.h"
20 #include "src/sksl/ir/SkSLStatement.h"
21 #include "src/sksl/transform/SkSLProgramWriter.h"
22 #include "src/sksl/transform/SkSLTransform.h"
23 
24 #include <memory>
25 #include <utility>
26 #include <vector>
27 
28 namespace SkSL {
29 
30 class Context;
31 
EliminateUnnecessaryBraces(const Context & context,Module & module)32 void Transform::EliminateUnnecessaryBraces(const Context& context, Module& module) {
33     class UnnecessaryBraceEliminator : public ProgramWriter {
34     public:
35         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
36             // We don't need to look inside expressions at all.
37             return false;
38         }
39 
40         bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
41             // Work from the innermost blocks to the outermost.
42             INHERITED::visitStatementPtr(stmt);
43 
44             switch (stmt->kind()) {
45                 case StatementKind::kIf: {
46                     IfStatement& ifStmt = stmt->as<IfStatement>();
47                     EliminateBracesFrom(ifStmt.ifTrue());
48                     EliminateBracesFrom(ifStmt.ifFalse());
49                     break;
50                 }
51                 case StatementKind::kFor: {
52                     ForStatement& forStmt = stmt->as<ForStatement>();
53                     EliminateBracesFrom(forStmt.statement());
54                     break;
55                 }
56                 case StatementKind::kDo: {
57                     DoStatement& doStmt = stmt->as<DoStatement>();
58                     EliminateBracesFrom(doStmt.statement());
59                     break;
60                 }
61                 default:
62                     break;
63             }
64 
65             // We always check the entire program.
66             return false;
67         }
68 
69         static void EliminateBracesFrom(std::unique_ptr<Statement>& stmt) {
70             if (!stmt || !stmt->is<Block>()) {
71                 return;
72             }
73             Block& block = stmt->as<Block>();
74             std::unique_ptr<Statement>* usefulStmt = nullptr;
75             for (std::unique_ptr<Statement>& childStmt : block.children()) {
76                 if (childStmt->isEmpty()) {
77                     continue;
78                 }
79                 if (usefulStmt) {
80                     // We found two non-empty statements. We can't eliminate braces from
81                     // this block.
82                     return;
83                 }
84                 // We found one non-empty statement.
85                 usefulStmt = &childStmt;
86             }
87 
88             if (!usefulStmt) {
89                 // This block held zero useful statements. Replace the block with a nop.
90                 stmt = Nop::Make();
91             } else {
92                 // This block held one useful statement. Replace the block with that statement.
93                 stmt = std::move(*usefulStmt);
94             }
95         }
96 
97         using INHERITED = ProgramWriter;
98     };
99 
100     class RequiredBraceWriter : public ProgramWriter {
101     public:
102         RequiredBraceWriter(const Context& ctx) : fContext(ctx) {}
103 
104         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
105             // We don't need to look inside expressions at all.
106             return false;
107         }
108 
109         bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
110             // Look for the following structure:
111             //
112             //    if (...)
113             //      if (...)
114             //        any statement;
115             //    else
116             //      any statement;
117             //
118             // This structure isn't correct if we emit it textually, because the else-clause would
119             // be interpreted as if it were bound to the inner if-statement, like this:
120             //
121             //    if (...) {
122             //      if (...)
123             //        any statement;
124             //      else
125             //        any statement;
126             //    }
127             //
128             // If we find such a structure, we must disambiguate the else-clause by adding braces:
129             //    if (...) {
130             //      if (...)
131             //        any statement;
132             //    } else
133             //      any statement;
134 
135             // Work from the innermost blocks to the outermost.
136             INHERITED::visitStatementPtr(stmt);
137 
138             // We are looking for an if-statement.
139             if (stmt->is<IfStatement>()) {
140                 IfStatement& outer = stmt->as<IfStatement>();
141 
142                 // It should have an else clause, and directly wrap another if-statement (no Block).
143                 if (outer.ifFalse() && outer.ifTrue()->is<IfStatement>()) {
144                     const IfStatement& inner = outer.ifTrue()->as<IfStatement>();
145 
146                     // The inner if statement shouldn't have an else clause.
147                     if (!inner.ifFalse()) {
148                         // This structure is ambiguous; the else clause on the outer if-statement
149                         // will bind to the inner if-statement if we don't add braces. We must wrap
150                         // the outer if-statement's true-clause in braces.
151                         StatementArray blockStmts;
152                         blockStmts.push_back(std::move(outer.ifTrue()));
153                         Position stmtPosition = blockStmts.front()->position();
154                         std::unique_ptr<Statement> bracedIfTrue =
155                                 Block::MakeBlock(stmtPosition, std::move(blockStmts));
156                         stmt = IfStatement::Make(fContext,
157                                                  outer.position(),
158                                                  std::move(outer.test()),
159                                                  std::move(bracedIfTrue),
160                                                  std::move(outer.ifFalse()));
161                     }
162                 }
163             }
164 
165             // We always check the entire program.
166             return false;
167         }
168 
169         const Context& fContext;
170         using INHERITED = ProgramWriter;
171     };
172 
173     for (std::unique_ptr<ProgramElement>& pe : module.fElements) {
174         if (pe->is<FunctionDefinition>()) {
175             // First, we eliminate braces around single-statement child blocks wherever possible.
176             UnnecessaryBraceEliminator eliminator;
177             eliminator.visitStatementPtr(pe->as<FunctionDefinition>().body());
178 
179             // The first pass can be overzealous, since it can remove so many braces that else-
180             // clauses are bound to the wrong if-statement. Search for this case and fix it up
181             // if we find it.
182             RequiredBraceWriter writer(context);
183             writer.visitStatementPtr(pe->as<FunctionDefinition>().body());
184         }
185     }
186 }
187 
188 }  // namespace SkSL
189