xref: /aosp_15_r20/external/skia/src/sksl/SkSLConstantFolder.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLConstantFolder.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkFloatingPoint.h"
12 #include "include/private/base/SkTArray.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLErrorReporter.h"
16 #include "src/sksl/SkSLPosition.h"
17 #include "src/sksl/SkSLProgramSettings.h"
18 #include "src/sksl/ir/SkSLBinaryExpression.h"
19 #include "src/sksl/ir/SkSLConstructorCompound.h"
20 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLExpression.h"
23 #include "src/sksl/ir/SkSLLiteral.h"
24 #include "src/sksl/ir/SkSLModifierFlags.h"
25 #include "src/sksl/ir/SkSLPrefixExpression.h"
26 #include "src/sksl/ir/SkSLType.h"
27 #include "src/sksl/ir/SkSLVariable.h"
28 #include "src/sksl/ir/SkSLVariableReference.h"
29 
30 #include <cstdint>
31 #include <float.h>
32 #include <limits>
33 #include <optional>
34 #include <string>
35 #include <utility>
36 
37 using namespace skia_private;
38 
39 namespace SkSL {
40 
is_vec_or_mat(const Type & type)41 static bool is_vec_or_mat(const Type& type) {
42     switch (type.typeKind()) {
43         case Type::TypeKind::kMatrix:
44         case Type::TypeKind::kVector:
45             return true;
46 
47         default:
48             return false;
49     }
50 }
51 
eliminate_no_op_boolean(Position pos,const Expression & left,Operator op,const Expression & right)52 static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53                                                            const Expression& left,
54                                                            Operator op,
55                                                            const Expression& right) {
56     bool rightVal = right.as<Literal>().boolValue();
57 
58     // Detect no-op Boolean expressions and optimize them away.
59     if ((op.kind() == Operator::Kind::LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
60         (op.kind() == Operator::Kind::LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
61         (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
62         (op.kind() == Operator::Kind::EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
63         (op.kind() == Operator::Kind::NEQ        && !rightVal)) {  // (expr != false) -> (expr)
64 
65         return left.clone(pos);
66     }
67 
68     return nullptr;
69 }
70 
short_circuit_boolean(Position pos,const Expression & left,Operator op,const Expression & right)71 static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72                                                          const Expression& left,
73                                                          Operator op,
74                                                          const Expression& right) {
75     bool leftVal = left.as<Literal>().boolValue();
76 
77     // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78     if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
79         (op.kind() == Operator::Kind::LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)
80 
81         return left.clone(pos);
82     }
83 
84     // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85     // simplify away a no-op expression.
86     return eliminate_no_op_boolean(pos, right, op, left);
87 }
88 
simplify_constant_equality(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)89 static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
90                                                               Position pos,
91                                                               const Expression& left,
92                                                               Operator op,
93                                                               const Expression& right) {
94     if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95         bool equality = (op.kind() == Operator::Kind::EQEQ);
96 
97         switch (left.compareConstant(right)) {
98             case Expression::ComparisonResult::kNotEqual:
99                 equality = !equality;
100                 [[fallthrough]];
101 
102             case Expression::ComparisonResult::kEqual:
103                 return Literal::MakeBool(context, pos, equality);
104 
105             case Expression::ComparisonResult::kUnknown:
106                 break;
107         }
108     }
109     return nullptr;
110 }
111 
simplify_matrix_multiplication(const Context & context,Position pos,const Expression & left,const Expression & right,int leftColumns,int leftRows,int rightColumns,int rightRows)112 static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
113                                                                   Position pos,
114                                                                   const Expression& left,
115                                                                   const Expression& right,
116                                                                   int leftColumns,
117                                                                   int leftRows,
118                                                                   int rightColumns,
119                                                                   int rightRows) {
120     const Type& componentType = left.type().componentType();
121     SkASSERT(componentType.matches(right.type().componentType()));
122 
123     // Fetch the left matrix.
124     double leftVals[4][4];
125     for (int c = 0; c < leftColumns; ++c) {
126         for (int r = 0; r < leftRows; ++r) {
127             leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
128         }
129     }
130     // Fetch the right matrix.
131     double rightVals[4][4];
132     for (int c = 0; c < rightColumns; ++c) {
133         for (int r = 0; r < rightRows; ++r) {
134             rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
135         }
136     }
137 
138     SkASSERT(leftColumns == rightRows);
139     int outColumns   = rightColumns,
140         outRows      = leftRows;
141 
142     double args[16];
143     int argIndex = 0;
144     for (int c = 0; c < outColumns; ++c) {
145         for (int r = 0; r < outRows; ++r) {
146             // Compute a dot product for this position.
147             double val = 0;
148             for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
149                 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
150             }
151 
152             if (val >= -FLT_MAX && val <= FLT_MAX) {
153                 args[argIndex++] = val;
154             } else {
155                 // The value is outside the 32-bit float range, or is NaN; do not optimize.
156                 return nullptr;
157             }
158         }
159     }
160 
161     if (outColumns == 1) {
162         // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
163         std::swap(outColumns, outRows);
164     }
165 
166     const Type& resultType = componentType.toCompound(context, outColumns, outRows);
167     return ConstructorCompound::MakeFromConstants(context, pos, resultType, args);
168 }
169 
simplify_matrix_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)170 static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
171                                                                 Position pos,
172                                                                 const Expression& left,
173                                                                 const Expression& right) {
174     const Type& leftType = left.type();
175     const Type& rightType = right.type();
176 
177     SkASSERT(leftType.isMatrix());
178     SkASSERT(rightType.isMatrix());
179 
180     return simplify_matrix_multiplication(context, pos, left, right,
181                                           leftType.columns(), leftType.rows(),
182                                           rightType.columns(), rightType.rows());
183 }
184 
simplify_vector_times_matrix(const Context & context,Position pos,const Expression & left,const Expression & right)185 static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
186                                                                 Position pos,
187                                                                 const Expression& left,
188                                                                 const Expression& right) {
189     const Type& leftType = left.type();
190     const Type& rightType = right.type();
191 
192     SkASSERT(leftType.isVector());
193     SkASSERT(rightType.isMatrix());
194 
195     return simplify_matrix_multiplication(context, pos, left, right,
196                                           /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197                                           rightType.columns(), rightType.rows());
198 }
199 
simplify_matrix_times_vector(const Context & context,Position pos,const Expression & left,const Expression & right)200 static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
201                                                                 Position pos,
202                                                                 const Expression& left,
203                                                                 const Expression& right) {
204     const Type& leftType = left.type();
205     const Type& rightType = right.type();
206 
207     SkASSERT(leftType.isMatrix());
208     SkASSERT(rightType.isVector());
209 
210     return simplify_matrix_multiplication(context, pos, left, right,
211                                           leftType.columns(), leftType.rows(),
212                                           /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213 }
214 
simplify_componentwise(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right)215 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
216                                                           Position pos,
217                                                           const Expression& left,
218                                                           Operator op,
219                                                           const Expression& right) {
220     SkASSERT(is_vec_or_mat(left.type()));
221     SkASSERT(left.type().matches(right.type()));
222     const Type& type = left.type();
223 
224     // Handle equality operations: == !=
225     if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226             right)) {
227         return result;
228     }
229 
230     // Handle floating-point arithmetic: + - * /
231     using FoldFn = double (*)(double, double);
232     FoldFn foldFn;
233     switch (op.kind()) {
234         case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
235         case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236         case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
237         case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238         default:
239             return nullptr;
240     }
241 
242     const Type& componentType = type.componentType();
243     SkASSERT(componentType.isNumber());
244 
245     double minimumValue = componentType.minimumValue();
246     double maximumValue = componentType.maximumValue();
247 
248     double args[16];
249     int numSlots = type.slotCount();
250     for (int i = 0; i < numSlots; i++) {
251         double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252         if (value < minimumValue || value > maximumValue) {
253             return nullptr;
254         }
255         args[i] = value;
256     }
257     return ConstructorCompound::MakeFromConstants(context, pos, type, args);
258 }
259 
splat_scalar(const Context & context,const Expression & scalar,const Type & type)260 static std::unique_ptr<Expression> splat_scalar(const Context& context,
261                                                 const Expression& scalar,
262                                                 const Type& type) {
263     if (type.isVector()) {
264         return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
265     }
266     if (type.isMatrix()) {
267         int numSlots = type.slotCount();
268         ExpressionArray splatMatrix;
269         splatMatrix.reserve_exact(numSlots);
270         for (int index = 0; index < numSlots; ++index) {
271             splatMatrix.push_back(scalar.clone());
272         }
273         return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
274     }
275     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276     return nullptr;
277 }
278 
cast_expression(const Context & context,Position pos,const Expression & expr,const Type & type)279 static std::unique_ptr<Expression> cast_expression(const Context& context,
280                                                    Position pos,
281                                                    const Expression& expr,
282                                                    const Type& type) {
283     SkASSERT(type.componentType().matches(expr.type().componentType()));
284     if (expr.type().isScalar()) {
285         if (type.isMatrix()) {
286             return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
287         }
288         if (type.isVector()) {
289             return ConstructorSplat::Make(context, pos, type, expr.clone());
290         }
291     }
292     if (type.matches(expr.type())) {
293         return expr.clone(pos);
294     }
295     // We can't cast matrices into vectors or vice-versa.
296     return nullptr;
297 }
298 
zero_expression(const Context & context,Position pos,const Type & type)299 static std::unique_ptr<Expression> zero_expression(const Context& context,
300                                                    Position pos,
301                                                    const Type& type) {
302     std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303     if (type.isScalar()) {
304         return zero;
305     }
306     if (type.isVector()) {
307         return ConstructorSplat::Make(context, pos, type, std::move(zero));
308     }
309     if (type.isMatrix()) {
310         return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311     }
312     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313     return nullptr;
314 }
315 
negate_expression(const Context & context,Position pos,const Expression & expr,const Type & type)316 static std::unique_ptr<Expression> negate_expression(const Context& context,
317                                                      Position pos,
318                                                      const Expression& expr,
319                                                      const Type& type) {
320     std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321     return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
322                 : nullptr;
323 }
324 
GetConstantInt(const Expression & value,SKSL_INT * out)325 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
326     const Expression* expr = GetConstantValueForVariable(value);
327     if (!expr->isIntLiteral()) {
328         return false;
329     }
330     *out = expr->as<Literal>().intValue();
331     return true;
332 }
333 
GetConstantValue(const Expression & value,double * out)334 bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
335     const Expression* expr = GetConstantValueForVariable(value);
336     if (!expr->is<Literal>()) {
337         return false;
338     }
339     *out = expr->as<Literal>().value();
340     return true;
341 }
342 
contains_constant_zero(const Expression & expr)343 static bool contains_constant_zero(const Expression& expr) {
344     int numSlots = expr.type().slotCount();
345     for (int index = 0; index < numSlots; ++index) {
346         std::optional<double> slotVal = expr.getConstantValue(index);
347         if (slotVal.has_value() && *slotVal == 0.0) {
348             return true;
349         }
350     }
351     return false;
352 }
353 
IsConstantSplat(const Expression & expr,double value)354 bool ConstantFolder::IsConstantSplat(const Expression& expr, double value) {
355     int numSlots = expr.type().slotCount();
356     for (int index = 0; index < numSlots; ++index) {
357         std::optional<double> slotVal = expr.getConstantValue(index);
358         if (!slotVal.has_value() || *slotVal != value) {
359             return false;
360         }
361     }
362     return true;
363 }
364 
365 // Returns true if the expression is a square diagonal matrix containing `value`.
is_constant_diagonal(const Expression & expr,double value)366 static bool is_constant_diagonal(const Expression& expr, double value) {
367     SkASSERT(expr.type().isMatrix());
368     int columns = expr.type().columns();
369     int rows = expr.type().rows();
370     if (columns != rows) {
371         return false;
372     }
373     int slotIdx = 0;
374     for (int c = 0; c < columns; ++c) {
375         for (int r = 0; r < rows; ++r) {
376             double expectation = (c == r) ? value : 0;
377             std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
378             if (!slotVal.has_value() || *slotVal != expectation) {
379                 return false;
380             }
381         }
382     }
383     return true;
384 }
385 
386 // Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
is_constant_value(const Expression & expr,double value)387 static bool is_constant_value(const Expression& expr, double value) {
388     return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
389                                   : ConstantFolder::IsConstantSplat(expr, value);
390 }
391 
392 // The expression represents the right-hand side of a division op. If the division can be
393 // strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394 // Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395 // Expression contains anything else.
make_reciprocal_expression(const Context & context,const Expression & right)396 static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397                                                               const Expression& right) {
398     if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399         return nullptr;
400     }
401     // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402     double values[4];
403     int nslots = right.type().slotCount();
404     for (int index = 0; index < nslots; ++index) {
405         std::optional<double> value = right.getConstantValue(index);
406         if (!value) {
407             return nullptr;
408         }
409         *value = sk_ieee_double_divide(1.0, *value);
410         if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411             // The reciprocal can be represented safely as a finite 32-bit float.
412             values[index] = *value;
413         } else {
414             // The value is outside the 32-bit float range, or is NaN; do not optimize.
415             return nullptr;
416         }
417     }
418     // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419     // this will return the literal as-is.)
420     return ConstructorCompound::MakeFromConstants(context, right.fPosition, right.type(), values);
421 }
422 
error_on_divide_by_zero(const Context & context,Position pos,Operator op,const Expression & right)423 static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424                                     const Expression& right) {
425     switch (op.kind()) {
426         case Operator::Kind::SLASH:
427         case Operator::Kind::SLASHEQ:
428         case Operator::Kind::PERCENT:
429         case Operator::Kind::PERCENTEQ:
430             if (contains_constant_zero(right)) {
431                 context.fErrors->error(pos, "division by zero");
432                 return true;
433             }
434             return false;
435         default:
436             return false;
437     }
438 }
439 
// Follows a chain of read-only references to const variables back to the initial-value expression
// at the end of the chain. Returns that expression if Analysis::IsCompileTimeConstant accepts it,
// and null otherwise. An input that is not a variable reference is itself tested.
const Expression* ConstantFolder::GetConstantValueOrNull(const Expression& inExpr) {
    const Expression* current = &inExpr;

    while (current->is<VariableReference>()) {
        const VariableReference& ref = current->as<VariableReference>();

        // Only plain reads of const-qualified variables can be replaced by their value.
        const bool isConstRead = ref.refKind() == VariableRefKind::kRead &&
                                 ref.variable()->modifierFlags().isConst();
        if (!isConstRead) {
            return nullptr;
        }

        current = ref.variable()->initialValue();
        if (current == nullptr) {
            // Generally, const variables must have initial values. However, function parameters
            // are an exception; they can be const but won't have an initial value.
            return nullptr;
        }
    }

    if (!Analysis::IsCompileTimeConstant(*current)) {
        return nullptr;
    }
    return current;
}
460 
// Like GetConstantValueOrNull, but falls back to the input expression itself when no constant
// value can be found, so the result is never null.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNull(inExpr)) {
        return constant;
    }
    return &inExpr;
}
465 
MakeConstantValueForVariable(Position pos,std::unique_ptr<Expression> inExpr)466 std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
467         Position pos, std::unique_ptr<Expression> inExpr) {
468     const Expression* expr = GetConstantValueOrNull(*inExpr);
469     return expr ? expr->clone(pos) : std::move(inExpr);
470 }
471 
is_scalar_op_matrix(const Expression & left,const Expression & right)472 static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
473     return left.type().isScalar() && right.type().isMatrix();
474 }
475 
is_matrix_op_scalar(const Expression & left,const Expression & right)476 static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
477     return is_scalar_op_matrix(right, left);
478 }
479 
simplify_arithmetic(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)480 static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
481                                                        Position pos,
482                                                        const Expression& left,
483                                                        Operator op,
484                                                        const Expression& right,
485                                                        const Type& resultType) {
486     switch (op.kind()) {
487         case Operator::Kind::PLUS:
488             if (!is_scalar_op_matrix(left, right) &&
489                 ConstantFolder::IsConstantSplat(right, 0.0)) {  // x + 0
490                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
491                                                                        resultType)) {
492                     return expr;
493                 }
494             }
495             if (!is_matrix_op_scalar(left, right) &&
496                 ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 + x
497                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
498                                                                        resultType)) {
499                     return expr;
500                 }
501             }
502             break;
503 
504         case Operator::Kind::STAR:
505             if (is_constant_value(right, 1.0)) {  // x * 1
506                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
507                                                                        resultType)) {
508                     return expr;
509                 }
510             }
511             if (is_constant_value(left, 1.0)) {   // 1 * x
512                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
513                                                                        resultType)) {
514                     return expr;
515                 }
516             }
517             if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) {  // x * 0
518                 return zero_expression(context, pos, resultType);
519             }
520             if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) {  // 0 * x
521                 return zero_expression(context, pos, resultType);
522             }
523             if (is_constant_value(right, -1.0)) {  // x * -1 (to `-x`)
524                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
525                                                                          resultType)) {
526                     return expr;
527                 }
528             }
529             if (is_constant_value(left, -1.0)) {  // -1 * x (to `-x`)
530                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
531                                                                          resultType)) {
532                     return expr;
533                 }
534             }
535             break;
536 
537         case Operator::Kind::MINUS:
538             if (!is_scalar_op_matrix(left, right) &&
539                 ConstantFolder::IsConstantSplat(right, 0.0)) {  // x - 0
540                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
541                                                                        resultType)) {
542                     return expr;
543                 }
544             }
545             if (!is_matrix_op_scalar(left, right) &&
546                 ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 - x
547                 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
548                                                                          resultType)) {
549                     return expr;
550                 }
551             }
552             break;
553 
554         case Operator::Kind::SLASH:
555             if (!is_scalar_op_matrix(left, right) &&
556                 ConstantFolder::IsConstantSplat(right, 1.0)) {  // x / 1
557                 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
558                                                                        resultType)) {
559                     return expr;
560                 }
561             }
562             if (!left.type().isMatrix()) {  // convert `x / 2` into `x * 0.5`
563                 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
564                     return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
565                                                   std::move(expr));
566                 }
567             }
568             break;
569 
570         case Operator::Kind::PLUSEQ:
571         case Operator::Kind::MINUSEQ:
572             if (ConstantFolder::IsConstantSplat(right, 0.0)) {  // x += 0, x -= 0
573                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
574                                                                       resultType)) {
575                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
576                     return var;
577                 }
578             }
579             break;
580 
581         case Operator::Kind::STAREQ:
582             if (is_constant_value(right, 1.0)) {  // x *= 1
583                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
584                                                                       resultType)) {
585                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
586                     return var;
587                 }
588             }
589             break;
590 
591         case Operator::Kind::SLASHEQ:
592             if (ConstantFolder::IsConstantSplat(right, 1.0)) {  // x /= 1
593                 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
594                                                                       resultType)) {
595                     Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
596                     return var;
597                 }
598             }
599             if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
600                 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
601                                               std::move(expr));
602             }
603             break;
604 
605         default:
606             break;
607     }
608 
609     return nullptr;
610 }
611 
612 // The expression must be scalar, and represents the right-hand side of a division op. It can
613 // contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
614 // expression might be further simplified by the constant folding, if possible.
one_over_scalar(const Context & context,const Expression & right)615 static std::unique_ptr<Expression> one_over_scalar(const Context& context,
616                                                    const Expression& right) {
617     SkASSERT(right.type().isScalar());
618     Position pos = right.fPosition;
619     return BinaryExpression::Make(context, pos,
620                                   Literal::Make(pos, 1.0, &right.type()),
621                                   Operator::Kind::SLASH,
622                                   right.clone());
623 }
624 
simplify_matrix_division(const Context & context,Position pos,const Expression & left,Operator op,const Expression & right,const Type & resultType)625 static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
626                                                             Position pos,
627                                                             const Expression& left,
628                                                             Operator op,
629                                                             const Expression& right,
630                                                             const Type& resultType) {
631     // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
632     // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
633     switch (op.kind()) {
634         case OperatorKind::SLASH:
635         case OperatorKind::SLASHEQ:
636             if (left.type().isMatrix() && right.type().isScalar()) {
637                 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
638                                                         : OperatorKind::STAR;
639                 return BinaryExpression::Make(context, pos,
640                                               left.clone(),
641                                               multiplyOp,
642                                               one_over_scalar(context, right));
643             }
644             break;
645 
646         default:
647             break;
648     }
649 
650     return nullptr;
651 }
652 
fold_expression(Position pos,double result,const Type * resultType)653 static std::unique_ptr<Expression> fold_expression(Position pos,
654                                                    double result,
655                                                    const Type* resultType) {
656     if (resultType->isNumber()) {
657         if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
658             // This result will fit inside its type.
659         } else {
660             // The value is outside the range or is NaN (all if-checks fail); do not optimize.
661             return nullptr;
662         }
663     }
664 
665     return Literal::Make(pos, result, resultType);
666 }
667 
fold_two_constants(const Context & context,Position pos,const Expression * left,Operator op,const Expression * right,const Type & resultType)668 static std::unique_ptr<Expression> fold_two_constants(const Context& context,
669                                                       Position pos,
670                                                       const Expression* left,
671                                                       Operator op,
672                                                       const Expression* right,
673                                                       const Type& resultType) {
674     SkASSERT(Analysis::IsCompileTimeConstant(*left));
675     SkASSERT(Analysis::IsCompileTimeConstant(*right));
676     const Type& leftType = left->type();
677     const Type& rightType = right->type();
678 
679     // Handle pairs of integer literals.
680     if (left->isIntLiteral() && right->isIntLiteral()) {
681         using SKSL_UINT = uint64_t;
682         SKSL_INT leftVal  = left->as<Literal>().intValue();
683         SKSL_INT rightVal = right->as<Literal>().intValue();
684 
685         // Note that fold_expression returns null if the result would overflow its type.
686         #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
687                                                   (SKSL_INT)(rightVal), &resultType)
688         #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
689                                                   (SKSL_UINT)(rightVal)), &resultType)
690         switch (op.kind()) {
691             case Operator::Kind::PLUS:       return URESULT(+);
692             case Operator::Kind::MINUS:      return URESULT(-);
693             case Operator::Kind::STAR:       return URESULT(*);
694             case Operator::Kind::SLASH:
695                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
696                     context.fErrors->error(pos, "arithmetic overflow");
697                     return nullptr;
698                 }
699                 return RESULT(/);
700 
701             case Operator::Kind::PERCENT:
702                 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
703                     context.fErrors->error(pos, "arithmetic overflow");
704                     return nullptr;
705                 }
706                 return RESULT(%);
707 
708             case Operator::Kind::BITWISEAND: return RESULT(&);
709             case Operator::Kind::BITWISEOR:  return RESULT(|);
710             case Operator::Kind::BITWISEXOR: return RESULT(^);
711             case Operator::Kind::EQEQ:       return RESULT(==);
712             case Operator::Kind::NEQ:        return RESULT(!=);
713             case Operator::Kind::GT:         return RESULT(>);
714             case Operator::Kind::GTEQ:       return RESULT(>=);
715             case Operator::Kind::LT:         return RESULT(<);
716             case Operator::Kind::LTEQ:       return RESULT(<=);
717             case Operator::Kind::SHL:
718                 if (rightVal >= 0 && rightVal <= 31) {
719                     // Left-shifting a negative (or really, any signed) value is undefined behavior
720                     // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
721                     // an UBSAN error.
722                     return URESULT(<<);
723                 }
724                 context.fErrors->error(pos, "shift value out of range");
725                 return nullptr;
726 
727             case Operator::Kind::SHR:
728                 if (rightVal >= 0 && rightVal <= 31) {
729                     return RESULT(>>);
730                 }
731                 context.fErrors->error(pos, "shift value out of range");
732                 return nullptr;
733 
734             default:
735                 break;
736         }
737         #undef RESULT
738         #undef URESULT
739 
740         return nullptr;
741     }
742 
743     // Handle pairs of floating-point literals.
744     if (left->isFloatLiteral() && right->isFloatLiteral()) {
745         SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
746         SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
747 
748         #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
749         switch (op.kind()) {
750             case Operator::Kind::PLUS:  return RESULT(+);
751             case Operator::Kind::MINUS: return RESULT(-);
752             case Operator::Kind::STAR:  return RESULT(*);
753             case Operator::Kind::SLASH: return RESULT(/);
754             case Operator::Kind::EQEQ:  return RESULT(==);
755             case Operator::Kind::NEQ:   return RESULT(!=);
756             case Operator::Kind::GT:    return RESULT(>);
757             case Operator::Kind::GTEQ:  return RESULT(>=);
758             case Operator::Kind::LT:    return RESULT(<);
759             case Operator::Kind::LTEQ:  return RESULT(<=);
760             default:                    break;
761         }
762         #undef RESULT
763 
764         return nullptr;
765     }
766 
767     // Perform matrix multiplication.
768     if (op.kind() == Operator::Kind::STAR) {
769         if (leftType.isMatrix() && rightType.isMatrix()) {
770             return simplify_matrix_times_matrix(context, pos, *left, *right);
771         }
772         if (leftType.isVector() && rightType.isMatrix()) {
773             return simplify_vector_times_matrix(context, pos, *left, *right);
774         }
775         if (leftType.isMatrix() && rightType.isVector()) {
776             return simplify_matrix_times_vector(context, pos, *left, *right);
777         }
778     }
779 
780     // Perform constant folding on pairs of vectors/matrices.
781     if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
782         return simplify_componentwise(context, pos, *left, op, *right);
783     }
784 
785     // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
786     if (rightType.isScalar() && is_vec_or_mat(leftType) &&
787         leftType.componentType().matches(rightType)) {
788         return simplify_componentwise(context, pos,
789                                       *left, op, *splat_scalar(context, *right, left->type()));
790     }
791 
792     // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
793     if (leftType.isScalar() && is_vec_or_mat(rightType) &&
794         rightType.componentType().matches(leftType)) {
795         return simplify_componentwise(context, pos,
796                                       *splat_scalar(context, *left, right->type()), op, *right);
797     }
798 
799     // Perform constant folding on pairs of matrices, arrays or structs.
800     if ((leftType.isMatrix() && rightType.isMatrix()) ||
801         (leftType.isArray() && rightType.isArray()) ||
802         (leftType.isStruct() && rightType.isStruct())) {
803         return simplify_constant_equality(context, pos, *left, op, *right);
804     }
805 
806     // We aren't able to constant-fold these expressions.
807     return nullptr;
808 }
809 
Simplify(const Context & context,Position pos,const Expression & leftExpr,Operator op,const Expression & rightExpr,const Type & resultType)810 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
811                                                      Position pos,
812                                                      const Expression& leftExpr,
813                                                      Operator op,
814                                                      const Expression& rightExpr,
815                                                      const Type& resultType) {
816     // Replace constant variables with their literal values.
817     const Expression* left = GetConstantValueForVariable(leftExpr);
818     const Expression* right = GetConstantValueForVariable(rightExpr);
819 
820     // If this is the assignment operator, and both sides are the same trivial expression, this is
821     // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
822     // This can happen when other parts of the assignment are optimized away.
823     if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
824         return right->clone(pos);
825     }
826 
827     // Simplify the expression when both sides are constant Boolean literals.
828     if (left->isBoolLiteral() && right->isBoolLiteral()) {
829         bool leftVal  = left->as<Literal>().boolValue();
830         bool rightVal = right->as<Literal>().boolValue();
831         bool result;
832         switch (op.kind()) {
833             case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
834             case Operator::Kind::LOGICALOR:  result = leftVal || rightVal; break;
835             case Operator::Kind::LOGICALXOR: result = leftVal ^  rightVal; break;
836             case Operator::Kind::EQEQ:       result = leftVal == rightVal; break;
837             case Operator::Kind::NEQ:        result = leftVal != rightVal; break;
838             default: return nullptr;
839         }
840         return Literal::MakeBool(context, pos, result);
841     }
842 
843     // If the left side is a Boolean literal, apply short-circuit optimizations.
844     if (left->isBoolLiteral()) {
845         return short_circuit_boolean(pos, *left, op, *right);
846     }
847 
848     // If the right side is a Boolean literal...
849     if (right->isBoolLiteral()) {
850         // ... and the left side has no side effects...
851         if (!Analysis::HasSideEffects(*left)) {
852             // We can reverse the expressions and short-circuit optimizations are still valid.
853             return short_circuit_boolean(pos, *right, op, *left);
854         }
855 
856         // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
857         return eliminate_no_op_boolean(pos, *left, op, *right);
858     }
859 
860     if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
861         // With == comparison, if both sides are the same trivial expression, this is self-
862         // comparison and is always true. (We are not concerned with NaN.)
863         return Literal::MakeBool(context, pos, /*value=*/true);
864     }
865 
866     if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
867         // With != comparison, if both sides are the same trivial expression, this is self-
868         // comparison and is always false. (We are not concerned with NaN.)
869         return Literal::MakeBool(context, pos, /*value=*/false);
870     }
871 
872     if (error_on_divide_by_zero(context, pos, op, *right)) {
873         return nullptr;
874     }
875 
876     // Perform full constant folding when both sides are compile-time constants.
877     bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
878     bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
879     if (leftSideIsConstant && rightSideIsConstant) {
880         return fold_two_constants(context, pos, left, op, right, resultType);
881     }
882 
883     if (context.fConfig->fSettings.fOptimize) {
884         // If just one side is constant, we might still be able to simplify arithmetic expressions
885         // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
886         if (leftSideIsConstant || rightSideIsConstant) {
887             if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
888                                                                        *right, resultType)) {
889                 return expr;
890             }
891         }
892 
893         // We can simplify some forms of matrix division even when neither side is constant.
894         if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
895                                                                         *right, resultType)) {
896             return expr;
897         }
898     }
899 
900     // We aren't able to constant-fold.
901     return nullptr;
902 }
903 
904 }  // namespace SkSL
905