xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLIndexExpression.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLIndexExpression.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLBuiltinTypes.h"
14 #include "src/sksl/SkSLConstantFolder.h"
15 #include "src/sksl/SkSLContext.h"
16 #include "src/sksl/SkSLDefines.h"
17 #include "src/sksl/SkSLErrorReporter.h"
18 #include "src/sksl/SkSLOperator.h"
19 #include "src/sksl/ir/SkSLConstructorArray.h"
20 #include "src/sksl/ir/SkSLConstructorCompound.h"
21 #include "src/sksl/ir/SkSLLiteral.h"
22 #include "src/sksl/ir/SkSLSwizzle.h"
23 #include "src/sksl/ir/SkSLSymbolTable.h"  // IWYU pragma: keep
24 #include "src/sksl/ir/SkSLType.h"
25 #include "src/sksl/ir/SkSLTypeReference.h"
26 
27 #include <cstdint>
28 #include <optional>
29 
30 namespace SkSL {
31 
index_out_of_range(const Context & context,Position pos,SKSL_INT index,const Expression & base)32 static bool index_out_of_range(const Context& context, Position pos, SKSL_INT index,
33         const Expression& base) {
34     if (index >= 0) {
35         if (base.type().columns() == Type::kUnsizedArray) {
36             return false;
37         } else if (index < base.type().columns()) {
38             return false;
39         }
40     }
41     context.fErrors->error(pos, "index " + std::to_string(index) + " out of range for '" +
42                                 base.type().displayName() + "'");
43     return true;
44 }
45 
// Returns the type produced by indexing a value of type `type`.
// For a float or half matrix this is the column vector type (floatN / halfN, where N is the
// matrix's row count). Every other type, including a matrix whose row count is not 2-4
// (asserted in debug builds), yields `type.componentType()`.
const Type& IndexExpression::IndexType(const Context& context, const Type& type) {
    if (!type.isMatrix()) {
        return type.componentType();
    }

    const Type& scalarType = type.componentType();
    if (scalarType.matches(*context.fTypes.fFloat)) {
        switch (type.rows()) {
            case 2:  return *context.fTypes.fFloat2;
            case 3:  return *context.fTypes.fFloat3;
            case 4:  return *context.fTypes.fFloat4;
            default: break;
        }
        // Only reached for an unexpected row count.
        SkASSERT(false);
    } else if (scalarType.matches(*context.fTypes.fHalf)) {
        switch (type.rows()) {
            case 2:  return *context.fTypes.fHalf2;
            case 3:  return *context.fTypes.fHalf3;
            case 4:  return *context.fTypes.fHalf4;
            default: break;
        }
        // Only reached for an unexpected row count.
        SkASSERT(false);
    }
    return scalarType;
}
66 
Convert(const Context & context,Position pos,std::unique_ptr<Expression> base,std::unique_ptr<Expression> index)67 std::unique_ptr<Expression> IndexExpression::Convert(const Context& context,
68                                                      Position pos,
69                                                      std::unique_ptr<Expression> base,
70                                                      std::unique_ptr<Expression> index) {
71     // Convert an array type reference: `int[10]`.
72     if (base->is<TypeReference>()) {
73         const Type& baseType = base->as<TypeReference>().value();
74         SKSL_INT arraySize = baseType.convertArraySize(context, pos, std::move(index));
75         if (!arraySize) {
76             return nullptr;
77         }
78         return TypeReference::Convert(
79                 context,
80                 pos,
81                 context.fSymbolTable->addArrayDimension(context, &baseType, arraySize));
82     }
83     // Convert an index expression with an expression inside of it: `arr[a * 3]`.
84     const Type& baseType = base->type();
85     if (!baseType.isArray() && !baseType.isMatrix() && !baseType.isVector()) {
86         context.fErrors->error(base->fPosition,
87                                "expected array, but found '" + baseType.displayName() + "'");
88         return nullptr;
89     }
90     if (!index->type().isInteger()) {
91         index = context.fTypes.fInt->coerceExpression(std::move(index), context);
92         if (!index) {
93             return nullptr;
94         }
95     }
96     // Perform compile-time bounds checking on constant-expression indices.
97     const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
98     if (indexExpr->isIntLiteral()) {
99         SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
100         if (index_out_of_range(context, index->fPosition, indexValue, *base)) {
101             return nullptr;
102         }
103     }
104     return IndexExpression::Make(context, pos, std::move(base), std::move(index));
105 }
106 
Make(const Context & context,Position pos,std::unique_ptr<Expression> base,std::unique_ptr<Expression> index)107 std::unique_ptr<Expression> IndexExpression::Make(const Context& context,
108                                                   Position pos,
109                                                   std::unique_ptr<Expression> base,
110                                                   std::unique_ptr<Expression> index) {
111     const Type& baseType = base->type();
112     SkASSERT(baseType.isArray() || baseType.isMatrix() || baseType.isVector());
113     SkASSERT(index->type().isInteger());
114 
115     const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
116     if (indexExpr->isIntLiteral()) {
117         SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
118         if (!index_out_of_range(context, index->fPosition, indexValue, *base)) {
119             if (baseType.isVector()) {
120                 // Constant array indexes on vectors can be converted to swizzles: `v[2]` --> `v.z`.
121                 // Swizzling is harmless and can unlock further simplifications for some base types.
122                 return Swizzle::Make(context, pos, std::move(base),
123                         ComponentArray{(int8_t)indexValue});
124             }
125 
126             if (baseType.isArray() && !Analysis::HasSideEffects(*base)) {
127                 // Indexing an constant array constructor with a constant index can just pluck out
128                 // the requested value from the array.
129                 const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
130                 if (baseExpr->is<ConstructorArray>()) {
131                     const ConstructorArray& arrayCtor = baseExpr->as<ConstructorArray>();
132                     const ExpressionArray& arguments = arrayCtor.arguments();
133                     SkASSERT(arguments.size() == baseType.columns());
134 
135                     return arguments[indexValue]->clone(pos);
136                 }
137             }
138 
139             if (baseType.isMatrix() && !Analysis::HasSideEffects(*base)) {
140                 // Matrices can be constructed with vectors that don't line up on column boundaries,
141                 // so extracting out the values from the constructor can be tricky. Fortunately, we
142                 // can reconstruct an equivalent vector using `getConstantValue`. If we
143                 // can't extract the data using `getConstantValue`, it wasn't constant and
144                 // we're not obligated to simplify anything.
145                 const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
146                 int vecWidth = baseType.rows();
147                 const Type& vecType = baseType.columnType(context);
148                 indexValue *= vecWidth;
149 
150                 double ctorArgs[4];
151                 bool allConstant = true;
152                 for (int slot = 0; slot < vecWidth; ++slot) {
153                     std::optional<double> slotVal = baseExpr->getConstantValue(indexValue + slot);
154                     if (slotVal.has_value()) {
155                         ctorArgs[slot] = *slotVal;
156                     } else {
157                         allConstant = false;
158                         break;
159                     }
160                 }
161 
162                 if (allConstant) {
163                     return ConstructorCompound::MakeFromConstants(context, pos, vecType, ctorArgs);
164                 }
165             }
166         }
167     }
168 
169     return std::make_unique<IndexExpression>(context, pos, std::move(base), std::move(index));
170 }
171 
description(OperatorPrecedence) const172 std::string IndexExpression::description(OperatorPrecedence) const {
173     return this->base()->description(OperatorPrecedence::kPostfix) + "[" +
174            this->index()->description(OperatorPrecedence::kExpression) + "]";
175 }
176 
177 }  // namespace SkSL
178