1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "include/private/base/SkTo.h"
13 #include "src/core/SkTHash.h"
14 #include "src/sksl/SkSLAnalysis.h"
15 #include "src/sksl/SkSLBuiltinTypes.h"
16 #include "src/sksl/SkSLConstantFolder.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/SkSLErrorReporter.h"
19 #include "src/sksl/SkSLProgramSettings.h"
20 #include "src/sksl/ir/SkSLBlock.h"
21 #include "src/sksl/ir/SkSLBreakStatement.h"
22 #include "src/sksl/ir/SkSLNop.h"
23 #include "src/sksl/ir/SkSLSwitchCase.h"
24 #include "src/sksl/ir/SkSLSymbolTable.h"
25 #include "src/sksl/ir/SkSLType.h"
26 #include "src/sksl/transform/SkSLProgramWriter.h"
27 #include "src/sksl/transform/SkSLTransform.h"
28
29 #include <algorithm>
30 #include <iterator>
31
32 using namespace skia_private;
33
34 namespace SkSL {
35
description() const36 std::string SwitchStatement::description() const {
37 return "switch (" + this->value()->description() + ") " + this->caseBlock()->description();
38 }
39
find_duplicate_case_values(const StatementArray & cases)40 static TArray<const SwitchCase*> find_duplicate_case_values(const StatementArray& cases) {
41 TArray<const SwitchCase*> duplicateCases;
42 THashSet<SKSL_INT> intValues;
43 bool foundDefault = false;
44
45 for (const std::unique_ptr<Statement>& stmt : cases) {
46 const SwitchCase* sc = &stmt->as<SwitchCase>();
47 if (sc->isDefault()) {
48 if (foundDefault) {
49 duplicateCases.push_back(sc);
50 continue;
51 }
52 foundDefault = true;
53 } else {
54 SKSL_INT value = sc->value();
55 if (intValues.contains(value)) {
56 duplicateCases.push_back(sc);
57 continue;
58 }
59 intValues.add(value);
60 }
61 }
62
63 return duplicateCases;
64 }
65
remove_break_statements(std::unique_ptr<Statement> & stmt)66 static void remove_break_statements(std::unique_ptr<Statement>& stmt) {
67 class RemoveBreaksWriter : public ProgramWriter {
68 public:
69 bool visitStatementPtr(std::unique_ptr<Statement>& stmt) override {
70 if (stmt->is<BreakStatement>()) {
71 stmt = Nop::Make();
72 return false;
73 }
74 return ProgramWriter::visitStatementPtr(stmt);
75 }
76
77 bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
78 return false;
79 }
80 };
81 RemoveBreaksWriter{}.visitStatementPtr(stmt);
82 }
83
block_for_case(Statement * caseBlock,SwitchCase * caseToCapture)84 static bool block_for_case(Statement* caseBlock, SwitchCase* caseToCapture) {
85 // This function reduces a switch to the matching case (or cases, if fallthrough occurs) when
86 // the switch-value is known and no conditional breaks exist. If conversion is not possible,
87 // false is returned and no changes are made. Conversion can fail if the switch contains
88 // conditional breaks.
89 //
90 // We have to be careful to not move any of the pointers until after we're sure we're going to
91 // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
92 // of action.
93 //
94 // First, we identify the code that would be run if the switch's value matches `caseToCapture`.
95 StatementArray& cases = caseBlock->as<Block>().children();
96 auto iter = cases.begin();
97 for (; iter != cases.end(); ++iter) {
98 const SwitchCase& sc = (*iter)->as<SwitchCase>();
99 if (&sc == caseToCapture) {
100 break;
101 }
102 }
103
104 // Next, walk forward through the rest of the switch. If we find a conditional break, we're
105 // stuck and can't simplify at all. If we find an unconditional break, we have a range of
106 // statements that we can use for simplification.
107 auto startIter = iter;
108 bool removeBreakStatements = false;
109 for (; iter != cases.end(); ++iter) {
110 std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
111 if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
112 // We can't reduce switch-cases to a block when they have conditional exits.
113 return false;
114 }
115 if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
116 // We found an unconditional exit. We can use this block, but we'll need to strip
117 // out the break statement if there is one.
118 removeBreakStatements = true;
119 ++iter;
120 break;
121 }
122 }
123
124 // We fell off the bottom of the switch or encountered a break. Next, we must strip down
125 // `caseBlock` to hold only the statements needed to execute `caseToCapture`. To do this, we
126 // eliminate the SwitchCase elements. This converts each `case n: stmt;` element into just
127 // `stmt;`. While doing this, we also move the elements to the front of the array if they
128 // weren't already there.
129 int numElements = SkToInt(std::distance(startIter, iter));
130 for (int index = 0; index < numElements; ++index, ++startIter) {
131 cases[index] = std::move((*startIter)->as<SwitchCase>().statement());
132 }
133
134 // Next, we shrink the statement array to destroy the excess statements.
135 cases.pop_back_n(cases.size() - numElements);
136
137 // If we found an unconditional break at the end, we need to eliminate that break.
138 if (removeBreakStatements) {
139 remove_break_statements(cases.back());
140 }
141
142 // We've stripped down `caseBlock` to contain only the captured case. Return true.
143 return true;
144 }
145
Convert(const Context & context,Position pos,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::unique_ptr<SymbolTable> symbolTable)146 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
147 Position pos,
148 std::unique_ptr<Expression> value,
149 ExpressionArray caseValues,
150 StatementArray caseStatements,
151 std::unique_ptr<SymbolTable> symbolTable) {
152 SkASSERT(caseValues.size() == caseStatements.size());
153
154 value = context.fTypes.fInt->coerceExpression(std::move(value), context);
155 if (!value) {
156 return nullptr;
157 }
158
159 StatementArray cases;
160 for (int i = 0; i < caseValues.size(); ++i) {
161 if (caseValues[i]) {
162 Position casePos = caseValues[i]->fPosition;
163 // Case values must be constant integers of the same type as the switch value
164 std::unique_ptr<Expression> caseValue = value->type().coerceExpression(
165 std::move(caseValues[i]), context);
166 if (!caseValue) {
167 return nullptr;
168 }
169 SKSL_INT intValue;
170 if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
171 context.fErrors->error(casePos, "case value must be a constant integer");
172 return nullptr;
173 }
174 cases.push_back(SwitchCase::Make(casePos, intValue, std::move(caseStatements[i])));
175 } else {
176 cases.push_back(SwitchCase::MakeDefault(pos, std::move(caseStatements[i])));
177 }
178 }
179
180 // Detect duplicate `case` labels and report an error.
181 TArray<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
182 if (!duplicateCases.empty()) {
183 for (const SwitchCase* sc : duplicateCases) {
184 if (sc->isDefault()) {
185 context.fErrors->error(sc->fPosition, "duplicate default case");
186 } else {
187 context.fErrors->error(sc->fPosition, "duplicate case value '" +
188 std::to_string(sc->value()) + "'");
189 }
190 }
191 return nullptr;
192 }
193
194 // If a switch-case has variable declarations at its top level, we want to create a scoped block
195 // around the switch, then move the variable declarations out of the switch body and into the
196 // outer scope. This prevents scoping issues in backends which don't offer a native switch.
197 // (skia:14375) It also allows static-switch optimization to work properly when variables are
198 // inherited from earlier fall-through cases. (oss-fuzz:70589)
199 std::unique_ptr<Block> block =
200 Transform::HoistSwitchVarDeclarationsAtTopLevel(context, cases, *symbolTable, pos);
201
202 std::unique_ptr<Statement> switchStmt = SwitchStatement::Make(
203 context, pos, std::move(value),
204 Block::MakeBlock(pos, std::move(cases), Block::Kind::kBracedScope,
205 std::move(symbolTable)));
206 if (block) {
207 // Add the switch statement to the end of the var-decl block.
208 block->children().push_back(std::move(switchStmt));
209 return block;
210 } else {
211 // Return the switch statement directly.
212 return switchStmt;
213 }
214 }
215
Make(const Context & context,Position pos,std::unique_ptr<Expression> value,std::unique_ptr<Statement> caseBlock)216 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
217 Position pos,
218 std::unique_ptr<Expression> value,
219 std::unique_ptr<Statement> caseBlock) {
220 // Confirm that every statement in `cases` is a SwitchCase.
221 const StatementArray& cases = caseBlock->as<Block>().children();
222 SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
223 return stmt->is<SwitchCase>();
224 }));
225
226 // Confirm that every switch-case value is unique.
227 SkASSERT(find_duplicate_case_values(cases).empty());
228
229 // Flatten switch statements if we're optimizing, and the value is known
230 if (context.fConfig->fSettings.fOptimize) {
231 SKSL_INT switchValue;
232 if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
233 SwitchCase* defaultCase = nullptr;
234 SwitchCase* matchingCase = nullptr;
235 for (const std::unique_ptr<Statement>& stmt : cases) {
236 SwitchCase& sc = stmt->as<SwitchCase>();
237 if (sc.isDefault()) {
238 defaultCase = ≻
239 continue;
240 }
241
242 if (sc.value() == switchValue) {
243 matchingCase = ≻
244 break;
245 }
246 }
247
248 if (!matchingCase) {
249 // No case value matches the switch value.
250 if (!defaultCase) {
251 // No default switch-case exists; the switch had no effect. We can eliminate the
252 // body of the switch entirely.
253 // There's still value in preserving the symbol table here, particularly when
254 // the input program is malformed, so we keep the Block itself. (oss-fuzz:70613)
255 caseBlock->as<Block>().children().clear();
256 return caseBlock;
257 }
258 // We had a default case; that's what we matched with.
259 matchingCase = defaultCase;
260 }
261
262 // Strip down our case block to contain only the matching case, if we can.
263 if (block_for_case(caseBlock.get(), matchingCase)) {
264 return caseBlock;
265 }
266 }
267 }
268
269 // The switch couldn't be optimized away; emit it normally.
270 return std::make_unique<SwitchStatement>(pos, std::move(value), std::move(caseBlock));
271 }
272
273 } // namespace SkSL
274