// xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLSwitchStatement.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
/*
 * Copyright 2021 Google LLC
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */
7 
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "include/private/base/SkTo.h"
13 #include "src/core/SkTHash.h"
14 #include "src/sksl/SkSLAnalysis.h"
15 #include "src/sksl/SkSLBuiltinTypes.h"
16 #include "src/sksl/SkSLConstantFolder.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/SkSLErrorReporter.h"
19 #include "src/sksl/SkSLProgramSettings.h"
20 #include "src/sksl/ir/SkSLBlock.h"
21 #include "src/sksl/ir/SkSLBreakStatement.h"
22 #include "src/sksl/ir/SkSLNop.h"
23 #include "src/sksl/ir/SkSLSwitchCase.h"
24 #include "src/sksl/ir/SkSLSymbolTable.h"
25 #include "src/sksl/ir/SkSLType.h"
26 #include "src/sksl/transform/SkSLProgramWriter.h"
27 #include "src/sksl/transform/SkSLTransform.h"
28 
29 #include <algorithm>
30 #include <iterator>
31 
32 using namespace skia_private;
33 
34 namespace SkSL {
35 
description() const36 std::string SwitchStatement::description() const {
37     return "switch (" + this->value()->description() + ") " + this->caseBlock()->description();
38 }
39 
find_duplicate_case_values(const StatementArray & cases)40 static TArray<const SwitchCase*> find_duplicate_case_values(const StatementArray& cases) {
41     TArray<const SwitchCase*> duplicateCases;
42     THashSet<SKSL_INT> intValues;
43     bool foundDefault = false;
44 
45     for (const std::unique_ptr<Statement>& stmt : cases) {
46         const SwitchCase* sc = &stmt->as<SwitchCase>();
47         if (sc->isDefault()) {
48             if (foundDefault) {
49                 duplicateCases.push_back(sc);
50                 continue;
51             }
52             foundDefault = true;
53         } else {
54             SKSL_INT value = sc->value();
55             if (intValues.contains(value)) {
56                 duplicateCases.push_back(sc);
57                 continue;
58             }
59             intValues.add(value);
60         }
61     }
62 
63     return duplicateCases;
64 }
65 
remove_break_statements(std::unique_ptr<Statement> & stmt)66 static void remove_break_statements(std::unique_ptr<Statement>& stmt) {
67     class RemoveBreaksWriter : public ProgramWriter {
68     public:
69         bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
70             if (stmt->is<BreakStatement>()) {
71                 stmt = Nop::Make();
72                 return false;
73             }
74             return ProgramWriter::visitStatementPtr(stmt);
75         }
76 
77         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
78             return false;
79         }
80     };
81     RemoveBreaksWriter{}.visitStatementPtr(stmt);
82 }
83 
block_for_case(Statement * caseBlock,SwitchCase * caseToCapture)84 static bool block_for_case(Statement* caseBlock, SwitchCase* caseToCapture) {
85     // This function reduces a switch to the matching case (or cases, if fallthrough occurs) when
86     // the switch-value is known and no conditional breaks exist. If conversion is not possible,
87     // false is returned and no changes are made. Conversion can fail if the switch contains
88     // conditional breaks.
89     //
90     // We have to be careful to not move any of the pointers until after we're sure we're going to
91     // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
92     // of action.
93     //
94     // First, we identify the code that would be run if the switch's value matches `caseToCapture`.
95     StatementArray& cases = caseBlock->as<Block>().children();
96     auto iter = cases.begin();
97     for (; iter != cases.end(); ++iter) {
98         const SwitchCase& sc = (*iter)->as<SwitchCase>();
99         if (&sc == caseToCapture) {
100             break;
101         }
102     }
103 
104     // Next, walk forward through the rest of the switch. If we find a conditional break, we're
105     // stuck and can't simplify at all. If we find an unconditional break, we have a range of
106     // statements that we can use for simplification.
107     auto startIter = iter;
108     bool removeBreakStatements = false;
109     for (; iter != cases.end(); ++iter) {
110         std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
111         if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
112             // We can't reduce switch-cases to a block when they have conditional exits.
113             return false;
114         }
115         if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
116             // We found an unconditional exit. We can use this block, but we'll need to strip
117             // out the break statement if there is one.
118             removeBreakStatements = true;
119             ++iter;
120             break;
121         }
122     }
123 
124     // We fell off the bottom of the switch or encountered a break. Next, we must strip down
125     // `caseBlock` to hold only the statements needed to execute `caseToCapture`. To do this, we
126     // eliminate the SwitchCase elements. This converts each `case n: stmt;` element into just
127     // `stmt;`. While doing this, we also move the elements to the front of the array if they
128     // weren't already there.
129     int numElements = SkToInt(std::distance(startIter, iter));
130     for (int index = 0; index < numElements; ++index, ++startIter) {
131         cases[index] = std::move((*startIter)->as<SwitchCase>().statement());
132     }
133 
134     // Next, we shrink the statement array to destroy the excess statements.
135     cases.pop_back_n(cases.size() - numElements);
136 
137     // If we found an unconditional break at the end, we need to eliminate that break.
138     if (removeBreakStatements) {
139         remove_break_statements(cases.back());
140     }
141 
142     // We've stripped down `caseBlock` to contain only the captured case. Return true.
143     return true;
144 }
145 
Convert(const Context & context,Position pos,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::unique_ptr<SymbolTable> symbolTable)146 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
147                                                     Position pos,
148                                                     std::unique_ptr<Expression> value,
149                                                     ExpressionArray caseValues,
150                                                     StatementArray caseStatements,
151                                                     std::unique_ptr<SymbolTable> symbolTable) {
152     SkASSERT(caseValues.size() == caseStatements.size());
153 
154     value = context.fTypes.fInt->coerceExpression(std::move(value), context);
155     if (!value) {
156         return nullptr;
157     }
158 
159     StatementArray cases;
160     for (int i = 0; i < caseValues.size(); ++i) {
161         if (caseValues[i]) {
162             Position casePos = caseValues[i]->fPosition;
163             // Case values must be constant integers of the same type as the switch value
164             std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
165                     std::move(caseValues[i]), context);
166             if (!caseValue) {
167                 return nullptr;
168             }
169             SKSL_INT intValue;
170             if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
171                 context.fErrors->error(casePos, "case value must be a constant integer");
172                 return nullptr;
173             }
174             cases.push_back(SwitchCase::Make(casePos, intValue, std::move(caseStatements[i])));
175         } else {
176             cases.push_back(SwitchCase::MakeDefault(pos, std::move(caseStatements[i])));
177         }
178     }
179 
180     // Detect duplicate `case` labels and report an error.
181     TArray<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
182     if (!duplicateCases.empty()) {
183         for (const SwitchCase* sc : duplicateCases) {
184             if (sc->isDefault()) {
185                 context.fErrors->error(sc->fPosition, "duplicate default case");
186             } else {
187                 context.fErrors->error(sc->fPosition, "duplicate case value '" +
188                                                       std::to_string(sc->value()) + "'");
189             }
190         }
191         return nullptr;
192     }
193 
194     // If a switch-case has variable declarations at its top level, we want to create a scoped block
195     // around the switch, then move the variable declarations out of the switch body and into the
196     // outer scope. This prevents scoping issues in backends which don't offer a native switch.
197     // (skia:14375) It also allows static-switch optimization to work properly when variables are
198     // inherited from earlier fall-through cases. (oss-fuzz:70589)
199     std::unique_ptr<Block> block =
200             Transform::HoistSwitchVarDeclarationsAtTopLevel(context, cases, *symbolTable, pos);
201 
202     std::unique_ptr<Statement> switchStmt = SwitchStatement::Make(
203             context, pos, std::move(value),
204             Block::MakeBlock(pos, std::move(cases), Block::Kind::kBracedScope,
205                              std::move(symbolTable)));
206     if (block) {
207         // Add the switch statement to the end of the var-decl block.
208         block->children().push_back(std::move(switchStmt));
209         return block;
210     } else {
211         // Return the switch statement directly.
212         return switchStmt;
213     }
214 }
215 
Make(const Context & context,Position pos,std::unique_ptr<Expression> value,std::unique_ptr<Statement> caseBlock)216 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
217                                                  Position pos,
218                                                  std::unique_ptr<Expression> value,
219                                                  std::unique_ptr<Statement> caseBlock) {
220     // Confirm that every statement in `cases` is a SwitchCase.
221     const StatementArray& cases = caseBlock->as<Block>().children();
222     SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
223         return stmt->is<SwitchCase>();
224     }));
225 
226     // Confirm that every switch-case value is unique.
227     SkASSERT(find_duplicate_case_values(cases).empty());
228 
229     // Flatten switch statements if we're optimizing, and the value is known
230     if (context.fConfig->fSettings.fOptimize) {
231         SKSL_INT switchValue;
232         if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
233             SwitchCase* defaultCase = nullptr;
234             SwitchCase* matchingCase = nullptr;
235             for (const std::unique_ptr<Statement>& stmt : cases) {
236                 SwitchCase& sc = stmt->as<SwitchCase>();
237                 if (sc.isDefault()) {
238                     defaultCase = &sc;
239                     continue;
240                 }
241 
242                 if (sc.value() == switchValue) {
243                     matchingCase = &sc;
244                     break;
245                 }
246             }
247 
248             if (!matchingCase) {
249                 // No case value matches the switch value.
250                 if (!defaultCase) {
251                     // No default switch-case exists; the switch had no effect. We can eliminate the
252                     // body of the switch entirely.
253                     // There's still value in preserving the symbol table here, particularly when
254                     // the input program is malformed, so we keep the Block itself. (oss-fuzz:70613)
255                     caseBlock->as<Block>().children().clear();
256                     return caseBlock;
257                 }
258                 // We had a default case; that's what we matched with.
259                 matchingCase = defaultCase;
260             }
261 
262             // Strip down our case block to contain only the matching case, if we can.
263             if (block_for_case(caseBlock.get(), matchingCase)) {
264                 return caseBlock;
265             }
266         }
267     }
268 
269     // The switch couldn't be optimized away; emit it normally.
270     return std::make_unique<SwitchStatement>(pos, std::move(value), std::move(caseBlock));
271 }
272 
273 }  // namespace SkSL
274