xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLConstructorCompound.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLConstructorCompound.h"
9 
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLConstantFolder.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLProgramSettings.h"
16 #include "src/sksl/ir/SkSLConstructorSplat.h"
17 #include "src/sksl/ir/SkSLExpression.h"
18 #include "src/sksl/ir/SkSLLiteral.h"
19 #include "src/sksl/ir/SkSLType.h"
20 
21 #include <algorithm>
22 #include <cstddef>
23 #include <numeric>
24 #include <string>
25 
26 namespace SkSL {
27 
is_safe_to_eliminate(const Type & type,const Expression & arg)28 static bool is_safe_to_eliminate(const Type& type, const Expression& arg) {
29     if (type.isScalar()) {
30         // A scalar "compound type" with a single scalar argument is a no-op and can be eliminated.
31         // (Pedantically, this isn't a compound at all, but it's harmless to allow and simplifies
32         // call sites which need to narrow a vector and may sometimes end up with a scalar.)
33         SkASSERTF(arg.type().matches(type), "Creating type '%s' from '%s'",
34                   type.description().c_str(), arg.type().description().c_str());
35         return true;
36     }
37     if (type.isVector() && arg.type().matches(type)) {
38         // A vector compound constructor containing a single argument of matching type can trivially
39         // be eliminated.
40         return true;
41     }
42     // This is a meaningful single-argument compound constructor (e.g. vector-from-matrix,
43     // matrix-from-vector).
44     return false;
45 }
46 
make_splat_from_arguments(const Type & type,const ExpressionArray & args)47 static const Expression* make_splat_from_arguments(const Type& type, const ExpressionArray& args) {
48     // Splats cannot represent a matrix.
49     if (type.isMatrix()) {
50         return nullptr;
51     }
52     const Expression* splatExpression = nullptr;
53     for (int index = 0; index < args.size(); ++index) {
54         // Arguments must only be scalars or a splat constructors (which can only contain scalars).
55         const Expression* expr;
56         if (args[index]->type().isScalar()) {
57             expr = args[index].get();
58         } else if (args[index]->is<ConstructorSplat>()) {
59             expr = args[index]->as<ConstructorSplat>().argument().get();
60         } else {
61             return nullptr;
62         }
63         // On the first iteration, just remember the expression we encountered.
64         if (index == 0) {
65             splatExpression = expr;
66             continue;
67         }
68         // On subsequent iterations, ensure that the expression we found matches the first one.
69         // (Note that IsSameExpressionTree will always reject an Expression with side effects.)
70         if (!Analysis::IsSameExpressionTree(*expr, *splatExpression)) {
71             return nullptr;
72         }
73     }
74 
75     return splatExpression;
76 }
77 
Make(const Context & context,Position pos,const Type & type,ExpressionArray args)78 std::unique_ptr<Expression> ConstructorCompound::Make(const Context& context,
79                                                       Position pos,
80                                                       const Type& type,
81                                                       ExpressionArray args) {
82     SkASSERT(type.isAllowedInES2(context));
83 
84     // All the arguments must have matching component type.
85     SkASSERT(std::all_of(args.begin(), args.end(), [&](const std::unique_ptr<Expression>& arg) {
86         const Type& argType = arg->type();
87         return (argType.isScalar() || argType.isVector() || argType.isMatrix()) &&
88                (argType.componentType().matches(type.componentType()));
89     }));
90 
91     // The slot count of the combined argument list must match the composite type's slot count.
92     SkASSERT(type.slotCount() ==
93              std::accumulate(args.begin(), args.end(), /*initial value*/ (size_t)0,
94                              [](size_t n, const std::unique_ptr<Expression>& arg) {
95                                  return n + arg->type().slotCount();
96                              }));
97     // No-op compound constructors (containing a single argument of the same type) are eliminated.
98     // (Even though this is a "compound constructor," we let scalars pass through here; it's
99     // harmless to allow and simplifies call sites which need to narrow a vector and may sometimes
100     // end up with a scalar.)
101     if (args.size() == 1 && is_safe_to_eliminate(type, *args.front())) {
102         args.front()->fPosition = pos;
103         return std::move(args.front());
104     }
105     // Beyond this point, the type must be a vector or matrix.
106     SkASSERT(type.isVector() || type.isMatrix());
107 
108     if (context.fConfig->fSettings.fOptimize) {
109         // Find ConstructorCompounds embedded inside other ConstructorCompounds and flatten them.
110         //   -  float4(float2(1, 2), 3, 4)                -->  float4(1, 2, 3, 4)
111         //   -  float4(w, float3(sin(x), cos(y), tan(z))) -->  float4(w, sin(x), cos(y), tan(z))
112         //   -  mat2(float2(a, b), float2(c, d))          -->  mat2(a, b, c, d)
113 
114         // See how many fields we would have if composite constructors were flattened out.
115         int fields = 0;
116         for (const std::unique_ptr<Expression>& arg : args) {
117             fields += arg->is<ConstructorCompound>()
118                               ? arg->as<ConstructorCompound>().arguments().size()
119                               : 1;
120         }
121 
122         // If we added up more fields than we're starting with, we found at least one input that can
123         // be flattened out.
124         if (fields > args.size()) {
125             ExpressionArray flattened;
126             flattened.reserve_exact(fields);
127             for (std::unique_ptr<Expression>& arg : args) {
128                 // For non-ConstructorCompound fields, move them over as-is.
129                 if (!arg->is<ConstructorCompound>()) {
130                     flattened.push_back(std::move(arg));
131                     continue;
132                 }
133                 // For ConstructorCompound fields, move over their inner arguments individually.
134                 ConstructorCompound& compositeCtor = arg->as<ConstructorCompound>();
135                 for (std::unique_ptr<Expression>& innerArg : compositeCtor.arguments()) {
136                     flattened.push_back(std::move(innerArg));
137                 }
138             }
139             args = std::move(flattened);
140         }
141     }
142 
143     // Replace constant variables with their corresponding values, so `float2(one, two)` can
144     // compile down to `float2(1.0, 2.0)` (the latter is a compile-time constant).
145     for (std::unique_ptr<Expression>& arg : args) {
146         arg = ConstantFolder::MakeConstantValueForVariable(pos, std::move(arg));
147     }
148 
149     if (context.fConfig->fSettings.fOptimize) {
150         // Reduce compound constructors to splats where possible.
151         if (const Expression* splat = make_splat_from_arguments(type, args)) {
152             return ConstructorSplat::Make(context, pos, type, splat->clone());
153         }
154     }
155 
156     return std::make_unique<ConstructorCompound>(pos, type, std::move(args));
157 }
158 
MakeFromConstants(const Context & context,Position pos,const Type & returnType,const double value[])159 std::unique_ptr<Expression> ConstructorCompound::MakeFromConstants(const Context& context,
160                                                                    Position pos,
161                                                                    const Type& returnType,
162                                                                    const double value[]) {
163     int numSlots = returnType.slotCount();
164     ExpressionArray array;
165     array.reserve_exact(numSlots);
166     for (int index = 0; index < numSlots; ++index) {
167         array.push_back(Literal::Make(pos, value[index], &returnType.componentType()));
168     }
169     return ConstructorCompound::Make(context, pos, returnType, std::move(array));
170 }
171 
172 }  // namespace SkSL
173