xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLFunctionCall.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLFunctionCall.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkFloatingPoint.h"
13 #include "include/private/base/SkTArray.h"
14 #include "include/private/base/SkTo.h"
15 #include "src/base/SkEnumBitMask.h"
16 #include "src/base/SkHalf.h"
17 #include "src/core/SkMatrixInvert.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLConstantFolder.h"
21 #include "src/sksl/SkSLContext.h"
22 #include "src/sksl/SkSLErrorReporter.h"
23 #include "src/sksl/SkSLIntrinsicList.h"
24 #include "src/sksl/SkSLOperator.h"
25 #include "src/sksl/SkSLProgramSettings.h"
26 #include "src/sksl/SkSLString.h"
27 #include "src/sksl/ir/SkSLChildCall.h"
28 #include "src/sksl/ir/SkSLConstructor.h"
29 #include "src/sksl/ir/SkSLConstructorCompound.h"
30 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
31 #include "src/sksl/ir/SkSLFunctionReference.h"
32 #include "src/sksl/ir/SkSLLayout.h"
33 #include "src/sksl/ir/SkSLLiteral.h"
34 #include "src/sksl/ir/SkSLMethodReference.h"
35 #include "src/sksl/ir/SkSLModifierFlags.h"
36 #include "src/sksl/ir/SkSLType.h"
37 #include "src/sksl/ir/SkSLTypeReference.h"
38 #include "src/sksl/ir/SkSLVariable.h"
39 #include "src/sksl/ir/SkSLVariableReference.h"
40 
41 #include <algorithm>
42 #include <array>
43 #include <cmath>
44 #include <cstdint>
45 #include <cstring>
46 #include <optional>
47 #include <string_view>
48 
49 namespace SkSL {
50 
51 using IntrinsicArguments = std::array<const Expression*, 3>;
52 
has_compile_time_constant_arguments(const ExpressionArray & arguments)53 static bool has_compile_time_constant_arguments(const ExpressionArray& arguments) {
54     for (const std::unique_ptr<Expression>& arg : arguments) {
55         const Expression* expr = ConstantFolder::GetConstantValueForVariable(*arg);
56         if (!Analysis::IsCompileTimeConstant(*expr)) {
57             return false;
58         }
59     }
60     return true;
61 }
62 
63 template <typename T>
64 static void type_check_expression(const Expression& expr);
65 
66 template <>
type_check_expression(const Expression & expr)67 void type_check_expression<float>(const Expression& expr) {
68     SkASSERT(expr.type().componentType().isFloat());
69 }
70 
71 template <>
type_check_expression(const Expression & expr)72 void type_check_expression<SKSL_INT>(const Expression& expr) {
73     SkASSERT(expr.type().componentType().isInteger());
74 }
75 
76 template <>
type_check_expression(const Expression & expr)77 void type_check_expression<bool>(const Expression& expr) {
78     SkASSERT(expr.type().componentType().isBoolean());
79 }
80 
81 using CoalesceFn = double (*)(double, double, double);
82 using FinalizeFn = double (*)(double);
83 
coalesce_n_way_vector(const Expression * arg0,const Expression * arg1,double startingState,const Type & returnType,CoalesceFn coalesce,FinalizeFn finalize)84 static std::unique_ptr<Expression> coalesce_n_way_vector(const Expression* arg0,
85                                                          const Expression* arg1,
86                                                          double startingState,
87                                                          const Type& returnType,
88                                                          CoalesceFn coalesce,
89                                                          FinalizeFn finalize) {
90     // Takes up to two vector or scalar arguments and coalesces them in sequence:
91     //     scalar = startingState;
92     //     scalar = coalesce(scalar, arg0.x, arg1.x);
93     //     scalar = coalesce(scalar, arg0.y, arg1.y);
94     //     scalar = coalesce(scalar, arg0.z, arg1.z);
95     //     scalar = coalesce(scalar, arg0.w, arg1.w);
96     //     scalar = finalize(scalar);
97     //
98     // If an argument is null, zero is passed to the coalesce function. If the arguments are a mix
99     // of scalars and vectors, the scalars are interpreted as a vector containing the same value for
100     // every component.
101 
102     Position pos = arg0->fPosition;
103     double minimumValue = returnType.componentType().minimumValue();
104     double maximumValue = returnType.componentType().maximumValue();
105 
106     const Type& vecType =          arg0->type().isVector()  ? arg0->type() :
107                           (arg1 && arg1->type().isVector()) ? arg1->type() :
108                                                               arg0->type();
109     SkASSERT(         arg0->type().componentType().matches(vecType.componentType()));
110     SkASSERT(!arg1 || arg1->type().componentType().matches(vecType.componentType()));
111 
112     double value = startingState;
113     int arg0Index = 0;
114     int arg1Index = 0;
115     for (int index = 0; index < vecType.columns(); ++index) {
116         std::optional<double> arg0Value = arg0->getConstantValue(arg0Index);
117         arg0Index += arg0->type().isVector() ? 1 : 0;
118         SkASSERT(arg0Value.has_value());
119 
120         std::optional<double> arg1Value = 0.0;
121         if (arg1) {
122             arg1Value = arg1->getConstantValue(arg1Index);
123             arg1Index += arg1->type().isVector() ? 1 : 0;
124             SkASSERT(arg1Value.has_value());
125         }
126 
127         value = coalesce(value, *arg0Value, *arg1Value);
128 
129         if (value >= minimumValue && value <= maximumValue) {
130             // This result will fit inside the return type.
131         } else {
132             // The value is outside the float range or is NaN (all if-checks fail); do not optimize.
133             return nullptr;
134         }
135     }
136 
137     if (finalize) {
138         value = finalize(value);
139     }
140 
141     return Literal::Make(pos, value, &returnType);
142 }
143 
144 template <typename T>
coalesce_vector(const IntrinsicArguments & arguments,double startingState,const Type & returnType,CoalesceFn coalesce,FinalizeFn finalize)145 static std::unique_ptr<Expression> coalesce_vector(const IntrinsicArguments& arguments,
146                                                    double startingState,
147                                                    const Type& returnType,
148                                                    CoalesceFn coalesce,
149                                                    FinalizeFn finalize) {
150     SkASSERT(arguments[0]);
151     SkASSERT(!arguments[1]);
152     type_check_expression<T>(*arguments[0]);
153 
154     return coalesce_n_way_vector(arguments[0], /*arg1=*/nullptr,
155                                  startingState, returnType, coalesce, finalize);
156 }
157 
158 template <typename T>
coalesce_pairwise_vectors(const IntrinsicArguments & arguments,double startingState,const Type & returnType,CoalesceFn coalesce,FinalizeFn finalize)159 static std::unique_ptr<Expression> coalesce_pairwise_vectors(const IntrinsicArguments& arguments,
160                                                              double startingState,
161                                                              const Type& returnType,
162                                                              CoalesceFn coalesce,
163                                                              FinalizeFn finalize) {
164     SkASSERT(arguments[0]);
165     SkASSERT(arguments[1]);
166     SkASSERT(!arguments[2]);
167     type_check_expression<T>(*arguments[0]);
168     type_check_expression<T>(*arguments[1]);
169 
170     return coalesce_n_way_vector(arguments[0], arguments[1],
171                                  startingState, returnType, coalesce, finalize);
172 }
173 
174 using CompareFn = bool (*)(double, double);
175 
optimize_comparison(const Context & context,const IntrinsicArguments & arguments,CompareFn compare)176 static std::unique_ptr<Expression> optimize_comparison(const Context& context,
177                                                        const IntrinsicArguments& arguments,
178                                                        CompareFn compare) {
179     const Expression* left = arguments[0];
180     const Expression* right = arguments[1];
181     SkASSERT(left);
182     SkASSERT(right);
183     SkASSERT(!arguments[2]);
184 
185     const Type& type = left->type();
186     SkASSERT(type.isVector());
187     SkASSERT(type.componentType().isScalar());
188     SkASSERT(type.matches(right->type()));
189 
190     double array[4];
191 
192     for (int index = 0; index < type.columns(); ++index) {
193         std::optional<double> leftValue = left->getConstantValue(index);
194         std::optional<double> rightValue = right->getConstantValue(index);
195         SkASSERT(leftValue.has_value());
196         SkASSERT(rightValue.has_value());
197         array[index] = compare(*leftValue, *rightValue) ? 1.0 : 0.0;
198     }
199 
200     const Type& bvecType = context.fTypes.fBool->toCompound(context, type.columns(), /*rows=*/1);
201     return ConstructorCompound::MakeFromConstants(context, left->fPosition, bvecType, array);
202 }
203 
204 using EvaluateFn = double (*)(double, double, double);
205 
evaluate_n_way_intrinsic(const Context & context,const Expression * arg0,const Expression * arg1,const Expression * arg2,const Type & returnType,EvaluateFn eval)206 static std::unique_ptr<Expression> evaluate_n_way_intrinsic(const Context& context,
207                                                             const Expression* arg0,
208                                                             const Expression* arg1,
209                                                             const Expression* arg2,
210                                                             const Type& returnType,
211                                                             EvaluateFn eval) {
212     // Takes up to three arguments and evaluates all of them, left-to-right, in tandem.
213     // Equivalent to constructing a new compound value containing the results from:
214     //     eval(arg0.x, arg1.x, arg2.x),
215     //     eval(arg0.y, arg1.y, arg2.y),
216     //     eval(arg0.z, arg1.z, arg2.z),
217     //     eval(arg0.w, arg1.w, arg2.w)
218     //
219     // If an argument is null, zero is passed to the evaluation function. If the arguments are a mix
220     // of scalars and compounds, scalars are interpreted as a compound containing the same value for
221     // every component.
222 
223     double minimumValue = returnType.componentType().minimumValue();
224     double maximumValue = returnType.componentType().maximumValue();
225     int slots = returnType.slotCount();
226     double array[16];
227 
228     int arg0Index = 0;
229     int arg1Index = 0;
230     int arg2Index = 0;
231     for (int index = 0; index < slots; ++index) {
232         std::optional<double> arg0Value = arg0->getConstantValue(arg0Index);
233         arg0Index += arg0->type().isScalar() ? 0 : 1;
234         SkASSERT(arg0Value.has_value());
235 
236         std::optional<double> arg1Value = 0.0;
237         if (arg1) {
238             arg1Value = arg1->getConstantValue(arg1Index);
239             arg1Index += arg1->type().isScalar() ? 0 : 1;
240             SkASSERT(arg1Value.has_value());
241         }
242 
243         std::optional<double> arg2Value = 0.0;
244         if (arg2) {
245             arg2Value = arg2->getConstantValue(arg2Index);
246             arg2Index += arg2->type().isScalar() ? 0 : 1;
247             SkASSERT(arg2Value.has_value());
248         }
249 
250         array[index] = eval(*arg0Value, *arg1Value, *arg2Value);
251 
252         if (array[index] >= minimumValue && array[index] <= maximumValue) {
253             // This result will fit inside the return type.
254         } else {
255             // The value is outside the float range or is NaN (all if-checks fail); do not optimize.
256             return nullptr;
257         }
258     }
259 
260     return ConstructorCompound::MakeFromConstants(context, arg0->fPosition, returnType, array);
261 }
262 
263 template <typename T>
evaluate_intrinsic(const Context & context,const IntrinsicArguments & arguments,const Type & returnType,EvaluateFn eval)264 static std::unique_ptr<Expression> evaluate_intrinsic(const Context& context,
265                                                       const IntrinsicArguments& arguments,
266                                                       const Type& returnType,
267                                                       EvaluateFn eval) {
268     SkASSERT(arguments[0]);
269     SkASSERT(!arguments[1]);
270     type_check_expression<T>(*arguments[0]);
271 
272     return evaluate_n_way_intrinsic(context, arguments[0], /*arg1=*/nullptr, /*arg2=*/nullptr,
273                                     returnType, eval);
274 }
275 
evaluate_intrinsic_numeric(const Context & context,const IntrinsicArguments & arguments,const Type & returnType,EvaluateFn eval)276 static std::unique_ptr<Expression> evaluate_intrinsic_numeric(const Context& context,
277                                                               const IntrinsicArguments& arguments,
278                                                               const Type& returnType,
279                                                               EvaluateFn eval) {
280     SkASSERT(arguments[0]);
281     SkASSERT(!arguments[1]);
282     const Type& type = arguments[0]->type().componentType();
283 
284     if (type.isFloat()) {
285         return evaluate_intrinsic<float>(context, arguments, returnType, eval);
286     }
287     if (type.isInteger()) {
288         return evaluate_intrinsic<SKSL_INT>(context, arguments, returnType, eval);
289     }
290 
291     SkDEBUGFAILF("unsupported type %s", type.description().c_str());
292     return nullptr;
293 }
294 
evaluate_pairwise_intrinsic(const Context & context,const IntrinsicArguments & arguments,const Type & returnType,EvaluateFn eval)295 static std::unique_ptr<Expression> evaluate_pairwise_intrinsic(const Context& context,
296                                                                const IntrinsicArguments& arguments,
297                                                                const Type& returnType,
298                                                                EvaluateFn eval) {
299     SkASSERT(arguments[0]);
300     SkASSERT(arguments[1]);
301     SkASSERT(!arguments[2]);
302     const Type& type = arguments[0]->type().componentType();
303 
304     if (type.isFloat()) {
305         type_check_expression<float>(*arguments[0]);
306         type_check_expression<float>(*arguments[1]);
307     } else if (type.isInteger()) {
308         type_check_expression<SKSL_INT>(*arguments[0]);
309         type_check_expression<SKSL_INT>(*arguments[1]);
310     } else {
311         SkDEBUGFAILF("unsupported type %s", type.description().c_str());
312         return nullptr;
313     }
314 
315     return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], /*arg2=*/nullptr,
316                                     returnType, eval);
317 }
318 
evaluate_3_way_intrinsic(const Context & context,const IntrinsicArguments & arguments,const Type & returnType,EvaluateFn eval)319 static std::unique_ptr<Expression> evaluate_3_way_intrinsic(const Context& context,
320                                                             const IntrinsicArguments& arguments,
321                                                             const Type& returnType,
322                                                             EvaluateFn eval) {
323     SkASSERT(arguments[0]);
324     SkASSERT(arguments[1]);
325     SkASSERT(arguments[2]);
326     const Type& type = arguments[0]->type().componentType();
327 
328     if (type.isFloat()) {
329         type_check_expression<float>(*arguments[0]);
330         type_check_expression<float>(*arguments[1]);
331         type_check_expression<float>(*arguments[2]);
332     } else if (type.isInteger()) {
333         type_check_expression<SKSL_INT>(*arguments[0]);
334         type_check_expression<SKSL_INT>(*arguments[1]);
335         type_check_expression<SKSL_INT>(*arguments[2]);
336     } else {
337         SkDEBUGFAILF("unsupported type %s", type.description().c_str());
338         return nullptr;
339     }
340 
341     return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], arguments[2],
342                                     returnType, eval);
343 }
344 
345 template <typename T1, typename T2>
pun_value(double val)346 static double pun_value(double val) {
347     // Interpret `val` as a value of type T1.
348     static_assert(sizeof(T1) == sizeof(T2));
349     T1 inputValue = (T1)val;
350     // Reinterpret those bits as a value of type T2.
351     T2 outputValue;
352     memcpy(&outputValue, &inputValue, sizeof(T2));
353     // Return the value-of-type-T2 as a double. (Non-finite values will prohibit optimization.)
354     return (double)outputValue;
355 }
356 
357 // Helper functions for optimizing all of our intrinsics.
358 namespace Intrinsics {
359 namespace {
360 
// CoalesceFn / FinalizeFn callbacks. Each coalesce_* receives (running value, component of the
// first argument, component of the second argument or 0 when there is none).

// length(v) = sqrt(sum of v[i]^2)
double coalesce_length(double sum, double x, double /*unused*/) {
    return sum + x * x;
}
double finalize_length(double sum) {
    return std::sqrt(sum);
}

// distance(p, q) = sqrt(sum of (p[i] - q[i])^2)
double coalesce_distance(double sum, double p, double q) {
    const double delta = p - q;
    return sum + delta * delta;
}
double finalize_distance(double sum) {
    return std::sqrt(sum);
}

// dot(x, y) = sum of x[i] * y[i]
double coalesce_dot(double sum, double x, double y) {
    return sum + x * y;
}

// any(v) / all(v): boolean reductions, encoded as 1.0 (true) / 0.0 (false).
double coalesce_any(double acc, double x, double /*unused*/) {
    return (acc != 0.0 || x != 0.0) ? 1.0 : 0.0;
}
double coalesce_all(double acc, double x, double /*unused*/) {
    return (acc != 0.0 && x != 0.0) ? 1.0 : 0.0;
}
370 
// CompareFn callbacks for the component-wise vector relational intrinsics.
bool compare_lessThan(double lhs, double rhs) {
    return lhs < rhs;
}
bool compare_lessThanEqual(double lhs, double rhs) {
    return lhs <= rhs;
}
bool compare_greaterThan(double lhs, double rhs) {
    return lhs > rhs;
}
bool compare_greaterThanEqual(double lhs, double rhs) {
    return lhs >= rhs;
}
bool compare_equal(double lhs, double rhs) {
    return lhs == rhs;
}
bool compare_notEqual(double lhs, double rhs) {
    return lhs != rhs;
}
377 
// Angle and trigonometry EvaluateFn callbacks. Trailing parameters an intrinsic does not use exist
// only to match the three-argument EvaluateFn signature.

// radians(d) = d * pi / 180 and degrees(r) = r * 180 / pi. The factors are derived from a
// full-precision pi. Truncated constants such as 0.0174532925 and 57.2957795 carry roughly 1e-9
// and 2e-10 relative error, which can flip the rounding of a folded float result.
double evaluate_radians(double a, double, double) {
    return a * (3.14159265358979323846 / 180.0);
}
double evaluate_degrees(double a, double, double) {
    return a * (180.0 / 3.14159265358979323846);
}

double evaluate_sin(double a, double, double)          { return std::sin(a); }
double evaluate_cos(double a, double, double)          { return std::cos(a); }
double evaluate_tan(double a, double, double)          { return std::tan(a); }
double evaluate_asin(double a, double, double)         { return std::asin(a); }
double evaluate_acos(double a, double, double)         { return std::acos(a); }
double evaluate_atan(double a, double, double)         { return std::atan(a); }
// Two-argument atan(y, x); same argument order as std::atan2.
double evaluate_atan2(double a, double b, double)      { return std::atan2(a, b); }
double evaluate_asinh(double a, double, double)        { return std::asinh(a); }
double evaluate_acosh(double a, double, double)        { return std::acosh(a); }
double evaluate_atanh(double a, double, double)        { return std::atanh(a); }
390 
// Exponential EvaluateFn callbacks (pow, exp, log, exp2, log2, sqrt).
double evaluate_pow(double base, double exponent, double /*unused*/) {
    return std::pow(base, exponent);
}
double evaluate_exp(double x, double /*unused*/, double /*unused*/) {
    return std::exp(x);
}
double evaluate_log(double x, double /*unused*/, double /*unused*/) {
    return std::log(x);
}
double evaluate_exp2(double x, double /*unused*/, double /*unused*/) {
    return std::exp2(x);
}
double evaluate_log2(double x, double /*unused*/, double /*unused*/) {
    return std::log2(x);
}
double evaluate_sqrt(double x, double /*unused*/, double /*unused*/) {
    return std::sqrt(x);
}
evaluate_inversesqrt(double a,double,double)397 double evaluate_inversesqrt(double a, double, double) {
398     return sk_ieee_double_divide(1.0, std::sqrt(a));
399 }
400 
// Basic arithmetic EvaluateFn callbacks; the third parameter is unused.
double evaluate_add(double lhs, double rhs, double /*unused*/) {
    return lhs + rhs;
}
double evaluate_sub(double lhs, double rhs, double /*unused*/) {
    return lhs - rhs;
}
double evaluate_mul(double lhs, double rhs, double /*unused*/) {
    return lhs * rhs;
}
evaluate_div(double a,double b,double)404 double evaluate_div(double a, double b, double)        { return sk_ieee_double_divide(a, b); }
// Common-function EvaluateFn callbacks. The ternary forms in min/max/clamp/saturate are kept
// deliberately: they fix which operand wins on ties (e.g. min(-0.0, +0.0) yields +0.0) and how
// NaN propagates.
double evaluate_abs(double x, double, double) {
    return std::abs(x);
}
// sign(x): -1, 0 or +1 (NaN maps to 0).
double evaluate_sign(double x, double, double) {
    return (x > 0) ? 1.0 : (x < 0) ? -1.0 : 0.0;
}
// -sign(x): +1 for negative, -1 for positive, 0 for zero (or NaN).
double evaluate_opposite_sign(double x, double, double) {
    return (x < 0) ? 1.0 : (x > 0) ? -1.0 : 0.0;
}
double evaluate_floor(double x, double, double) {
    return std::floor(x);
}
double evaluate_ceil(double x, double, double) {
    return std::ceil(x);
}
double evaluate_fract(double x, double, double) {
    return x - std::floor(x);
}
double evaluate_min(double a, double b, double) {
    return (a < b) ? a : b;
}
double evaluate_max(double a, double b, double) {
    return (a > b) ? a : b;
}
double evaluate_clamp(double x, double lo, double hi) {
    return (x < lo) ? lo : (x > hi) ? hi : x;
}
double evaluate_fma(double a, double b, double c) {
    return a * b + c;
}
double evaluate_saturate(double x, double, double) {
    return (x < 0) ? 0 : (x > 1) ? 1 : x;
}
// mix(x, y, a) = x * (1 - a) + y * a
double evaluate_mix(double x, double y, double a) {
    return x * (1 - a) + y * a;
}
// step(edge, x) = 0 when x < edge, otherwise 1.
double evaluate_step(double edge, double x, double) {
    return (x < edge) ? 0.0 : 1.0;
}
evaluate_mod(double a,double b,double)418 double evaluate_mod(double a, double b, double) {
419     return a - b * std::floor(sk_ieee_double_divide(a, b));
420 }
evaluate_smoothstep(double edge0,double edge1,double x)421 double evaluate_smoothstep(double edge0, double edge1, double x) {
422     double t = sk_ieee_double_divide(x - edge0, edge1 - edge0);
423     t = (t < 0) ? 0 : (t > 1) ? 1 : t;
424     return t * t * (3.0 - 2.0 * t);
425 }
426 
// matrixCompMult(x, y): component-wise product of two matrices, one component per call.
double evaluate_matrixCompMult(double lhs, double rhs, double /*unused*/) {
    return lhs * rhs;
}
428 
// not(b): boolean negation, with booleans encoded as 1.0 / 0.0.
double evaluate_not(double x, double, double) {
    return (x == 0.0) ? 1.0 : 0.0;
}

// Hyperbolic functions and truncation.
double evaluate_sinh(double x, double, double)  { return std::sinh(x); }
double evaluate_cosh(double x, double, double)  { return std::cosh(x); }
double evaluate_tanh(double x, double, double)  { return std::tanh(x); }
double evaluate_trunc(double x, double, double) { return std::trunc(x); }

// round(x). std::remainder(x, 1.0) is the signed distance to the nearest integer, with ties
// resolved toward the even integer independently of the current rounding mode. Subtracting it
// therefore rounds half-to-even.
double evaluate_round(double x, double, double) {
    const double offsetFromNearest = std::remainder(x, 1.0);
    return x - offsetFromNearest;
}
evaluate_floatBitsToInt(double a,double,double)439 double evaluate_floatBitsToInt(double a, double, double)  { return pun_value<float, int32_t> (a); }
evaluate_floatBitsToUint(double a,double,double)440 double evaluate_floatBitsToUint(double a, double, double) { return pun_value<float, uint32_t>(a); }
evaluate_intBitsToFloat(double a,double,double)441 double evaluate_intBitsToFloat(double a, double, double)  { return pun_value<int32_t,  float>(a); }
evaluate_uintBitsToFloat(double a,double,double)442 double evaluate_uintBitsToFloat(double a, double, double) { return pun_value<uint32_t, float>(a); }
443 
evaluate_length(const IntrinsicArguments & arguments)444 std::unique_ptr<Expression> evaluate_length(const IntrinsicArguments& arguments) {
445     return coalesce_vector<float>(arguments, /*startingState=*/0,
446                                   arguments[0]->type().componentType(),
447                                   coalesce_length,
448                                   finalize_length);
449 }
450 
evaluate_distance(const IntrinsicArguments & arguments)451 std::unique_ptr<Expression> evaluate_distance(const IntrinsicArguments& arguments) {
452     return coalesce_pairwise_vectors<float>(arguments, /*startingState=*/0,
453                                             arguments[0]->type().componentType(),
454                                             coalesce_distance,
455                                             finalize_distance);
456 }
evaluate_dot(const IntrinsicArguments & arguments)457 std::unique_ptr<Expression> evaluate_dot(const IntrinsicArguments& arguments) {
458     return coalesce_pairwise_vectors<float>(arguments, /*startingState=*/0,
459                                             arguments[0]->type().componentType(),
460                                             coalesce_dot,
461                                             /*finalize=*/nullptr);
462 }
463 
evaluate_sign(const Context & context,const IntrinsicArguments & arguments)464 std::unique_ptr<Expression> evaluate_sign(const Context& context,
465                                           const IntrinsicArguments& arguments) {
466     return evaluate_intrinsic_numeric(context, arguments, arguments[0]->type(),
467                                       evaluate_sign);
468 }
469 
evaluate_opposite_sign(const Context & context,const IntrinsicArguments & arguments)470 std::unique_ptr<Expression> evaluate_opposite_sign(const Context& context,
471                                                    const IntrinsicArguments& arguments) {
472     return evaluate_intrinsic_numeric(context, arguments, arguments[0]->type(),
473                                       evaluate_opposite_sign);
474 }
475 
evaluate_add(const Context & context,const IntrinsicArguments & arguments)476 std::unique_ptr<Expression> evaluate_add(const Context& context,
477                                          const IntrinsicArguments& arguments) {
478     return evaluate_pairwise_intrinsic(context, arguments, arguments[0]->type(),
479                                        evaluate_add);
480 }
481 
evaluate_sub(const Context & context,const IntrinsicArguments & arguments)482 std::unique_ptr<Expression> evaluate_sub(const Context& context,
483                                          const IntrinsicArguments& arguments) {
484     return evaluate_pairwise_intrinsic(context, arguments, arguments[0]->type(),
485                                        evaluate_sub);
486 }
487 
evaluate_mul(const Context & context,const IntrinsicArguments & arguments)488 std::unique_ptr<Expression> evaluate_mul(const Context& context,
489                                          const IntrinsicArguments& arguments) {
490     return evaluate_pairwise_intrinsic(context, arguments, arguments[0]->type(),
491                                        evaluate_mul);
492 }
493 
evaluate_div(const Context & context,const IntrinsicArguments & arguments)494 std::unique_ptr<Expression> evaluate_div(const Context& context,
495                                          const IntrinsicArguments& arguments) {
496     return evaluate_pairwise_intrinsic(context, arguments, arguments[0]->type(),
497                                        evaluate_div);
498 }
499 
evaluate_normalize(const Context & context,const IntrinsicArguments & arguments)500 std::unique_ptr<Expression> evaluate_normalize(const Context& context,
501                                                const IntrinsicArguments& arguments) {
502     // normalize(v): v / length(v)
503     std::unique_ptr<Expression> length = Intrinsics::evaluate_length(arguments);
504     if (!length) { return nullptr; }
505 
506     const IntrinsicArguments divArgs = {arguments[0], length.get(), nullptr};
507     return Intrinsics::evaluate_div(context, divArgs);
508 }
509 
evaluate_faceforward(const Context & context,const IntrinsicArguments & arguments)510 std::unique_ptr<Expression> evaluate_faceforward(const Context& context,
511                                                  const IntrinsicArguments& arguments) {
512     const Expression* N = arguments[0];     // vector
513     const Expression* I = arguments[1];     // vector
514     const Expression* NRef = arguments[2];  // vector
515 
516     // faceforward(N,I,NRef): N * -sign(dot(I, NRef))
517     const IntrinsicArguments dotArgs = {I, NRef, nullptr};
518     std::unique_ptr<Expression> dotExpr = Intrinsics::evaluate_dot(dotArgs);
519     if (!dotExpr) { return nullptr; }
520 
521     const IntrinsicArguments signArgs = {dotExpr.get(), nullptr, nullptr};
522     std::unique_ptr<Expression> signExpr = Intrinsics::evaluate_opposite_sign(context, signArgs);
523     if (!signExpr) { return nullptr; }
524 
525     const IntrinsicArguments mulArgs = {N, signExpr.get(), nullptr};
526     return Intrinsics::evaluate_mul(context, mulArgs);
527 }
528 
evaluate_reflect(const Context & context,const IntrinsicArguments & arguments)529 std::unique_ptr<Expression> evaluate_reflect(const Context& context,
530                                              const IntrinsicArguments& arguments) {
531     const Expression* I = arguments[0];  // vector
532     const Expression* N = arguments[1];  // vector
533 
534     // reflect(I,N): temp = (N * dot(N, I)); reflect = I - (temp + temp)
535     const IntrinsicArguments dotArgs = {N, I, nullptr};
536     std::unique_ptr<Expression> dotExpr = Intrinsics::evaluate_dot(dotArgs);
537     if (!dotExpr) { return nullptr; }
538 
539     const IntrinsicArguments mulArgs = {N, dotExpr.get(), nullptr};
540     std::unique_ptr<Expression> mulExpr = Intrinsics::evaluate_mul(context, mulArgs);
541     if (!mulExpr) { return nullptr; }
542 
543     const IntrinsicArguments addArgs = {mulExpr.get(), mulExpr.get(), nullptr};
544     std::unique_ptr<Expression> addExpr = Intrinsics::evaluate_add(context, addArgs);
545     if (!addExpr) { return nullptr; }
546 
547     const IntrinsicArguments subArgs = {I, addExpr.get(), nullptr};
548     return Intrinsics::evaluate_sub(context, subArgs);
549 }
550 
evaluate_refract(const Context & context,const IntrinsicArguments & arguments)551 std::unique_ptr<Expression> evaluate_refract(const Context& context,
552                                              const IntrinsicArguments& arguments) {
553     const Expression* I = arguments[0];    // vector
554     const Expression* N = arguments[1];    // vector
555     const Expression* Eta = arguments[2];  // scalar
556 
557     // K = 1.0 - Eta^2 * (1.0 - Dot(N, I)^2);
558 
559     // DotNI = Dot(N, I)
560     const IntrinsicArguments DotNIArgs = {N, I, nullptr};
561     std::unique_ptr<Expression> DotNIExpr = Intrinsics::evaluate_dot(DotNIArgs);
562     if (!DotNIExpr) { return nullptr; }
563 
564     // DotNI2 = DotNI * DotNI
565     const IntrinsicArguments DotNI2Args = {DotNIExpr.get(), DotNIExpr.get(), nullptr};
566     std::unique_ptr<Expression> DotNI2Expr = Intrinsics::evaluate_mul(context, DotNI2Args);
567     if (!DotNI2Expr) { return nullptr; }
568 
569     // OneMinusDot = 1 - DotNI2
570     Literal oneLiteral{Position{}, 1.0, &DotNI2Expr->type()};
571     const IntrinsicArguments OneMinusDotArgs = {&oneLiteral, DotNI2Expr.get(), nullptr};
572     std::unique_ptr<Expression> OneMinusDotExpr= Intrinsics::evaluate_sub(context, OneMinusDotArgs);
573     if (!OneMinusDotExpr) { return nullptr; }
574 
575     // Eta2 = Eta * Eta
576     const IntrinsicArguments Eta2Args = {Eta, Eta, nullptr};
577     std::unique_ptr<Expression> Eta2Expr = Intrinsics::evaluate_mul(context, Eta2Args);
578     if (!Eta2Expr) { return nullptr; }
579 
580     // Eta2xDot = Eta2 * OneMinusDot
581     const IntrinsicArguments Eta2xDotArgs = {Eta2Expr.get(), OneMinusDotExpr.get(), nullptr};
582     std::unique_ptr<Expression> Eta2xDotExpr = Intrinsics::evaluate_mul(context, Eta2xDotArgs);
583     if (!Eta2xDotExpr) { return nullptr; }
584 
585     // K = 1.0 - Eta2xDot
586     const IntrinsicArguments KArgs = {&oneLiteral, Eta2xDotExpr.get(), nullptr};
587     std::unique_ptr<Expression> KExpr = Intrinsics::evaluate_sub(context, KArgs);
588     if (!KExpr || !KExpr->is<Literal>()) { return nullptr; }
589 
590     // When K < 0, Refract(I, N, Eta) = vec(0)
591     double kValue = KExpr->as<Literal>().value();
592     if (kValue < 0) {
593         constexpr double kZero[4] = {};
594         return ConstructorCompound::MakeFromConstants(context, I->fPosition, I->type(), kZero);
595     }
596 
597     // When K ≥ 0, Refract(I, N, Eta) = (I * Eta) - N * (Eta * Dot(N,I) + Sqrt(K))
598 
599     // EtaDot = Eta * DotNI
600     const IntrinsicArguments EtaDotArgs = {Eta, DotNIExpr.get(), nullptr};
601     std::unique_ptr<Expression> EtaDotExpr = Intrinsics::evaluate_mul(context, EtaDotArgs);
602     if (!EtaDotExpr) { return nullptr; }
603 
604     // EtaDotSqrt = EtaDot + Sqrt(K)
605     Literal sqrtKLiteral{Position{}, std::sqrt(kValue), &Eta->type()};
606     const IntrinsicArguments EtaDotSqrtArgs = {EtaDotExpr.get(), &sqrtKLiteral, nullptr};
607     std::unique_ptr<Expression> EtaDotSqrtExpr = Intrinsics::evaluate_add(context, EtaDotSqrtArgs);
608     if (!EtaDotSqrtExpr) { return nullptr; }
609 
610     // NxEDS = N * EtaDotSqrt
611     const IntrinsicArguments NxEDSArgs = {N, EtaDotSqrtExpr.get(), nullptr};
612     std::unique_ptr<Expression> NxEDSExpr = Intrinsics::evaluate_mul(context, NxEDSArgs);
613     if (!NxEDSExpr) { return nullptr; }
614 
615     // IEta = I * Eta
616     const IntrinsicArguments IEtaArgs = {I, Eta, nullptr};
617     std::unique_ptr<Expression> IEtaExpr = Intrinsics::evaluate_mul(context, IEtaArgs);
618     if (!IEtaExpr) { return nullptr; }
619 
620     // Refract = IEta - NxEDS
621     const IntrinsicArguments RefractArgs = {IEtaExpr.get(), NxEDSExpr.get(), nullptr};
622     return Intrinsics::evaluate_sub(context, RefractArgs);
623 }
624 
625 }  // namespace
626 }  // namespace Intrinsics
627 
extract_matrix(const Expression * expr,float mat[16])628 static void extract_matrix(const Expression* expr, float mat[16]) {
629     size_t numSlots = expr->type().slotCount();
630     for (size_t index = 0; index < numSlots; ++index) {
631         mat[index] = *expr->getConstantValue(index);
632     }
633 }
634 
optimize_intrinsic_call(const Context & context,Position pos,IntrinsicKind intrinsic,const ExpressionArray & argArray,const Type & returnType)635 static std::unique_ptr<Expression> optimize_intrinsic_call(const Context& context,
636                                                            Position pos,
637                                                            IntrinsicKind intrinsic,
638                                                            const ExpressionArray& argArray,
639                                                            const Type& returnType) {
640     // Replace constant variables with their literal values.
641     IntrinsicArguments arguments = {};
642     SkASSERT(SkToSizeT(argArray.size()) <= arguments.size());
643     for (int index = 0; index < argArray.size(); ++index) {
644         arguments[index] = ConstantFolder::GetConstantValueForVariable(*argArray[index]);
645     }
646 
647     auto Get = [&](int idx, int col) -> float {
648         return *arguments[idx]->getConstantValue(col);
649     };
650 
651     switch (intrinsic) {
652         // 8.1 : Angle and Trigonometry Functions
653         case k_radians_IntrinsicKind:
654             return evaluate_intrinsic<float>(context, arguments, returnType,
655                                              Intrinsics::evaluate_radians);
656         case k_degrees_IntrinsicKind:
657             return evaluate_intrinsic<float>(context, arguments, returnType,
658                                              Intrinsics::evaluate_degrees);
659         case k_sin_IntrinsicKind:
660             return evaluate_intrinsic<float>(context, arguments, returnType,
661                                              Intrinsics::evaluate_sin);
662         case k_cos_IntrinsicKind:
663             return evaluate_intrinsic<float>(context, arguments, returnType,
664                                              Intrinsics::evaluate_cos);
665         case k_tan_IntrinsicKind:
666             return evaluate_intrinsic<float>(context, arguments, returnType,
667                                              Intrinsics::evaluate_tan);
668         case k_sinh_IntrinsicKind:
669             return evaluate_intrinsic<float>(context, arguments, returnType,
670                                              Intrinsics::evaluate_sinh);
671         case k_cosh_IntrinsicKind:
672             return evaluate_intrinsic<float>(context, arguments, returnType,
673                                              Intrinsics::evaluate_cosh);
674         case k_tanh_IntrinsicKind:
675             return evaluate_intrinsic<float>(context, arguments, returnType,
676                                              Intrinsics::evaluate_tanh);
677         case k_asin_IntrinsicKind:
678             return evaluate_intrinsic<float>(context, arguments, returnType,
679                                              Intrinsics::evaluate_asin);
680         case k_acos_IntrinsicKind:
681             return evaluate_intrinsic<float>(context, arguments, returnType,
682                                              Intrinsics::evaluate_acos);
683         case k_atan_IntrinsicKind:
684             if (argArray.size() == 1) {
685                 return evaluate_intrinsic<float>(context, arguments, returnType,
686                                                  Intrinsics::evaluate_atan);
687             } else {
688                 return evaluate_pairwise_intrinsic(context, arguments, returnType,
689                                                    Intrinsics::evaluate_atan2);
690             }
691         case k_asinh_IntrinsicKind:
692             return evaluate_intrinsic<float>(context, arguments, returnType,
693                                              Intrinsics::evaluate_asinh);
694 
695         case k_acosh_IntrinsicKind:
696             return evaluate_intrinsic<float>(context, arguments, returnType,
697                                              Intrinsics::evaluate_acosh);
698         case k_atanh_IntrinsicKind:
699             return evaluate_intrinsic<float>(context, arguments, returnType,
700                                              Intrinsics::evaluate_atanh);
701         // 8.2 : Exponential Functions
702         case k_pow_IntrinsicKind:
703             return evaluate_pairwise_intrinsic(context, arguments, returnType,
704                                                Intrinsics::evaluate_pow);
705         case k_exp_IntrinsicKind:
706             return evaluate_intrinsic<float>(context, arguments, returnType,
707                                              Intrinsics::evaluate_exp);
708         case k_log_IntrinsicKind:
709             return evaluate_intrinsic<float>(context, arguments, returnType,
710                                              Intrinsics::evaluate_log);
711         case k_exp2_IntrinsicKind:
712             return evaluate_intrinsic<float>(context, arguments, returnType,
713                                              Intrinsics::evaluate_exp2);
714         case k_log2_IntrinsicKind:
715             return evaluate_intrinsic<float>(context, arguments, returnType,
716                                              Intrinsics::evaluate_log2);
717         case k_sqrt_IntrinsicKind:
718             return evaluate_intrinsic<float>(context, arguments, returnType,
719                                              Intrinsics::evaluate_sqrt);
720         case k_inversesqrt_IntrinsicKind:
721             return evaluate_intrinsic<float>(context, arguments, returnType,
722                                              Intrinsics::evaluate_inversesqrt);
723         // 8.3 : Common Functions
724         case k_abs_IntrinsicKind:
725             return evaluate_intrinsic_numeric(context, arguments, returnType,
726                                               Intrinsics::evaluate_abs);
727         case k_sign_IntrinsicKind:
728             return Intrinsics::evaluate_sign(context, arguments);
729 
730         case k_floor_IntrinsicKind:
731             return evaluate_intrinsic<float>(context, arguments, returnType,
732                                              Intrinsics::evaluate_floor);
733         case k_ceil_IntrinsicKind:
734             return evaluate_intrinsic<float>(context, arguments, returnType,
735                                              Intrinsics::evaluate_ceil);
736         case k_fract_IntrinsicKind:
737             return evaluate_intrinsic<float>(context, arguments, returnType,
738                                              Intrinsics::evaluate_fract);
739         case k_mod_IntrinsicKind:
740             return evaluate_pairwise_intrinsic(context, arguments, returnType,
741                                                Intrinsics::evaluate_mod);
742         case k_min_IntrinsicKind:
743             return evaluate_pairwise_intrinsic(context, arguments, returnType,
744                                                Intrinsics::evaluate_min);
745         case k_max_IntrinsicKind:
746             return evaluate_pairwise_intrinsic(context, arguments, returnType,
747                                                Intrinsics::evaluate_max);
748         case k_clamp_IntrinsicKind:
749             return evaluate_3_way_intrinsic(context, arguments, returnType,
750                                             Intrinsics::evaluate_clamp);
751         case k_fma_IntrinsicKind:
752             return evaluate_3_way_intrinsic(context, arguments, returnType,
753                                             Intrinsics::evaluate_fma);
754         case k_saturate_IntrinsicKind:
755             return evaluate_intrinsic<float>(context, arguments, returnType,
756                                              Intrinsics::evaluate_saturate);
757         case k_mix_IntrinsicKind:
758             if (arguments[2]->type().componentType().isBoolean()) {
759                 const SkSL::Type& numericType = arguments[0]->type().componentType();
760 
761                 if (numericType.isFloat()) {
762                     type_check_expression<float>(*arguments[0]);
763                     type_check_expression<float>(*arguments[1]);
764                 } else if (numericType.isInteger()) {
765                     type_check_expression<SKSL_INT>(*arguments[0]);
766                     type_check_expression<SKSL_INT>(*arguments[1]);
767                 } else if (numericType.isBoolean()) {
768                     type_check_expression<bool>(*arguments[0]);
769                     type_check_expression<bool>(*arguments[1]);
770                 } else {
771                     SkDEBUGFAILF("unsupported type %s", numericType.description().c_str());
772                     return nullptr;
773                 }
774                 return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], arguments[2],
775                                                 returnType, Intrinsics::evaluate_mix);
776             } else {
777                 return evaluate_3_way_intrinsic(context, arguments, returnType,
778                                                 Intrinsics::evaluate_mix);
779             }
780         case k_step_IntrinsicKind:
781             return evaluate_pairwise_intrinsic(context, arguments, returnType,
782                                                Intrinsics::evaluate_step);
783         case k_smoothstep_IntrinsicKind:
784             return evaluate_3_way_intrinsic(context, arguments, returnType,
785                                             Intrinsics::evaluate_smoothstep);
786         case k_trunc_IntrinsicKind:
787             return evaluate_intrinsic<float>(context, arguments, returnType,
788                                              Intrinsics::evaluate_trunc);
789         case k_round_IntrinsicKind:      // GLSL `round` documents its rounding mode as unspecified
790         case k_roundEven_IntrinsicKind:  // and is allowed to behave identically to `roundEven`.
791             return evaluate_intrinsic<float>(context, arguments, returnType,
792                                              Intrinsics::evaluate_round);
793         case k_floatBitsToInt_IntrinsicKind:
794             return evaluate_intrinsic<float>(context, arguments, returnType,
795                                              Intrinsics::evaluate_floatBitsToInt);
796         case k_floatBitsToUint_IntrinsicKind:
797             return evaluate_intrinsic<float>(context, arguments, returnType,
798                                              Intrinsics::evaluate_floatBitsToUint);
799         case k_intBitsToFloat_IntrinsicKind:
800             return evaluate_intrinsic<SKSL_INT>(context, arguments, returnType,
801                                                 Intrinsics::evaluate_intBitsToFloat);
802         case k_uintBitsToFloat_IntrinsicKind:
803             return evaluate_intrinsic<SKSL_INT>(context, arguments, returnType,
804                                                 Intrinsics::evaluate_uintBitsToFloat);
805         // 8.4 : Floating-Point Pack and Unpack Functions
806         case k_packUnorm2x16_IntrinsicKind: {
807             auto Pack = [&](int n) -> unsigned int {
808                 float x = Get(0, n);
809                 return (int)std::round(Intrinsics::evaluate_clamp(x, 0.0, 1.0) * 65535.0);
810             };
811             const double packed = ((Pack(0) << 0)  & 0x0000FFFF) |
812                                   ((Pack(1) << 16) & 0xFFFF0000);
813             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
814                                                           *context.fTypes.fUInt, &packed);
815         }
816         case k_packSnorm2x16_IntrinsicKind: {
817             auto Pack = [&](int n) -> unsigned int {
818                 float x = Get(0, n);
819                 return (int)std::round(Intrinsics::evaluate_clamp(x, -1.0, 1.0) * 32767.0);
820             };
821             const double packed = ((Pack(0) << 0)  & 0x0000FFFF) |
822                                   ((Pack(1) << 16) & 0xFFFF0000);
823             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
824                                                           *context.fTypes.fUInt, &packed);
825         }
826         case k_packHalf2x16_IntrinsicKind: {
827             auto Pack = [&](int n) -> unsigned int {
828                 return SkFloatToHalf(Get(0, n));
829             };
830             const double packed = ((Pack(0) << 0)  & 0x0000FFFF) |
831                                   ((Pack(1) << 16) & 0xFFFF0000);
832             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
833                                                           *context.fTypes.fUInt, &packed);
834         }
835         case k_unpackUnorm2x16_IntrinsicKind: {
836             SKSL_INT x = *arguments[0]->getConstantValue(0);
837             uint16_t a = ((x >> 0)  & 0x0000FFFF);
838             uint16_t b = ((x >> 16) & 0x0000FFFF);
839             const double unpacked[2] = {double(a) / 65535.0,
840                                         double(b) / 65535.0};
841             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
842                                                           *context.fTypes.fFloat2, unpacked);
843         }
844         case k_unpackSnorm2x16_IntrinsicKind: {
845             SKSL_INT x = *arguments[0]->getConstantValue(0);
846             int16_t a = ((x >> 0)  & 0x0000FFFF);
847             int16_t b = ((x >> 16) & 0x0000FFFF);
848             const double unpacked[2] = {Intrinsics::evaluate_clamp(double(a) / 32767.0, -1.0, 1.0),
849                                         Intrinsics::evaluate_clamp(double(b) / 32767.0, -1.0, 1.0)};
850             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
851                                                           *context.fTypes.fFloat2, unpacked);
852         }
853         case k_unpackHalf2x16_IntrinsicKind: {
854             SKSL_INT x = *arguments[0]->getConstantValue(0);
855             uint16_t a = ((x >> 0)  & 0x0000FFFF);
856             uint16_t b = ((x >> 16) & 0x0000FFFF);
857             const double unpacked[2] = {SkHalfToFloat(a),
858                                         SkHalfToFloat(b)};
859             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
860                                                           *context.fTypes.fFloat2, unpacked);
861         }
862         // 8.5 : Geometric Functions
863         case k_length_IntrinsicKind:
864             return Intrinsics::evaluate_length(arguments);
865 
866         case k_distance_IntrinsicKind:
867             return Intrinsics::evaluate_distance(arguments);
868 
869         case k_dot_IntrinsicKind:
870             return Intrinsics::evaluate_dot(arguments);
871 
872         case k_cross_IntrinsicKind: {
873             auto X = [&](int n) -> float { return Get(0, n); };
874             auto Y = [&](int n) -> float { return Get(1, n); };
875             SkASSERT(arguments[0]->type().columns() == 3);  // the vec2 form is not a real intrinsic
876 
877             double vec[3] = {X(1) * Y(2) - Y(1) * X(2),
878                              X(2) * Y(0) - Y(2) * X(0),
879                              X(0) * Y(1) - Y(0) * X(1)};
880             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
881                                                           returnType, vec);
882         }
883         case k_normalize_IntrinsicKind:
884             return Intrinsics::evaluate_normalize(context, arguments);
885 
886         case k_faceforward_IntrinsicKind:
887             return Intrinsics::evaluate_faceforward(context, arguments);
888 
889         case k_reflect_IntrinsicKind:
890             return Intrinsics::evaluate_reflect(context, arguments);
891 
892         case k_refract_IntrinsicKind:
893             return Intrinsics::evaluate_refract(context, arguments);
894 
895         // 8.6 : Matrix Functions
896         case k_matrixCompMult_IntrinsicKind:
897             return evaluate_pairwise_intrinsic(context, arguments, returnType,
898                                                Intrinsics::evaluate_matrixCompMult);
899         case k_transpose_IntrinsicKind: {
900             double mat[16];
901             int index = 0;
902             for (int c = 0; c < returnType.columns(); ++c) {
903                 for (int r = 0; r < returnType.rows(); ++r) {
904                     mat[index++] = Get(0, (returnType.columns() * r) + c);
905                 }
906             }
907             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
908                                                           returnType, mat);
909         }
910         case k_outerProduct_IntrinsicKind: {
911             double mat[16];
912             int index = 0;
913             for (int c = 0; c < returnType.columns(); ++c) {
914                 for (int r = 0; r < returnType.rows(); ++r) {
915                     mat[index++] = Get(0, r) * Get(1, c);
916                 }
917             }
918             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
919                                                           returnType, mat);
920         }
921         case k_determinant_IntrinsicKind: {
922             float mat[16];
923             extract_matrix(arguments[0], mat);
924             float determinant;
925             switch (arguments[0]->type().slotCount()) {
926                 case 4:
927                     determinant = SkInvert2x2Matrix(mat, /*outMatrix=*/nullptr);
928                     break;
929                 case 9:
930                     determinant = SkInvert3x3Matrix(mat, /*outMatrix=*/nullptr);
931                     break;
932                 case 16:
933                     determinant = SkInvert4x4Matrix(mat, /*outMatrix=*/nullptr);
934                     break;
935                 default:
936                     SkDEBUGFAILF("unsupported type %s", arguments[0]->type().description().c_str());
937                     return nullptr;
938             }
939             return Literal::MakeFloat(arguments[0]->fPosition, determinant, &returnType);
940         }
941         case k_inverse_IntrinsicKind: {
942             float mat[16] = {};
943             extract_matrix(arguments[0], mat);
944             switch (arguments[0]->type().slotCount()) {
945                 case 4:
946                     if (SkInvert2x2Matrix(mat, mat) == 0.0f) {
947                         return nullptr;
948                     }
949                     break;
950                 case 9:
951                     if (SkInvert3x3Matrix(mat, mat) == 0.0f) {
952                         return nullptr;
953                     }
954                     break;
955                 case 16:
956                     if (SkInvert4x4Matrix(mat, mat) == 0.0f) {
957                         return nullptr;
958                     }
959                     break;
960                 default:
961                     SkDEBUGFAILF("unsupported type %s", arguments[0]->type().description().c_str());
962                     return nullptr;
963             }
964 
965             double dmat[16];
966             std::copy(mat, mat + std::size(mat), dmat);
967             return ConstructorCompound::MakeFromConstants(context, arguments[0]->fPosition,
968                                                          returnType, dmat);
969         }
970         // 8.7 : Vector Relational Functions
971         case k_lessThan_IntrinsicKind:
972             return optimize_comparison(context, arguments, Intrinsics::compare_lessThan);
973 
974         case k_lessThanEqual_IntrinsicKind:
975             return optimize_comparison(context, arguments, Intrinsics::compare_lessThanEqual);
976 
977         case k_greaterThan_IntrinsicKind:
978             return optimize_comparison(context, arguments, Intrinsics::compare_greaterThan);
979 
980         case k_greaterThanEqual_IntrinsicKind:
981             return optimize_comparison(context, arguments, Intrinsics::compare_greaterThanEqual);
982 
983         case k_equal_IntrinsicKind:
984             return optimize_comparison(context, arguments, Intrinsics::compare_equal);
985 
986         case k_notEqual_IntrinsicKind:
987             return optimize_comparison(context, arguments, Intrinsics::compare_notEqual);
988 
989         case k_any_IntrinsicKind:
990             return coalesce_vector<bool>(arguments, /*startingState=*/false, returnType,
991                                          Intrinsics::coalesce_any,
992                                          /*finalize=*/nullptr);
993         case k_all_IntrinsicKind:
994             return coalesce_vector<bool>(arguments, /*startingState=*/true, returnType,
995                                          Intrinsics::coalesce_all,
996                                          /*finalize=*/nullptr);
997         case k_not_IntrinsicKind:
998             return evaluate_intrinsic<bool>(context, arguments, returnType,
999                                             Intrinsics::evaluate_not);
1000         default:
1001             return nullptr;
1002     }
1003 }
1004 
clone(Position pos) const1005 std::unique_ptr<Expression> FunctionCall::clone(Position pos) const {
1006     return std::make_unique<FunctionCall>(pos, &this->type(), &this->function(),
1007                                           this->arguments().clone(), this->stablePointer());
1008 }
1009 
description(OperatorPrecedence) const1010 std::string FunctionCall::description(OperatorPrecedence) const {
1011     std::string result = std::string(this->function().name()) + "(";
1012     auto separator = SkSL::String::Separator();
1013     for (const std::unique_ptr<Expression>& arg : this->arguments()) {
1014         result += separator();
1015         result += arg->description(OperatorPrecedence::kSequence);
1016     }
1017     result += ")";
1018     return result;
1019 }
1020 
argument_and_parameter_flags_match(const Expression & argument,const Variable & parameter)1021 static bool argument_and_parameter_flags_match(const Expression& argument,
1022                                                const Variable& parameter) {
1023     // If the function parameter has a pixel format, the argument being passed in must have a
1024     // matching pixel format.
1025     LayoutFlags paramPixelFormat = parameter.layout().fFlags & LayoutFlag::kAllPixelFormats;
1026     if (paramPixelFormat != LayoutFlag::kNone) {
1027         // The only SkSL type that supports pixel-format qualifiers is a storage texture.
1028         if (parameter.type().isStorageTexture()) {
1029             // Storage textures are opaquely typed, so there's no way to specify one other than by
1030             // directly accessing a variable.
1031             if (!argument.is<VariableReference>()) {
1032                 return false;
1033             }
1034 
1035             // The variable's pixel-format flags must match. (Only one pixel-format bit can be set.)
1036             const Variable& var = *argument.as<VariableReference>().variable();
1037             if ((var.layout().fFlags & LayoutFlag::kAllPixelFormats) != paramPixelFormat) {
1038                 return false;
1039             }
1040         }
1041     }
1042 
1043     // The only other supported parameter flags are `const` and `in/out`, which do not allow
1044     // multiple overloads.
1045     return true;
1046 }
1047 
1048 /**
1049  * Used to determine the best overload for a function call by calculating the cost of coercing the
1050  * arguments of the function to the required types. Cost has no particular meaning other than "lower
1051  * costs are preferred". Returns CoercionCost::Impossible() if the call is not valid. This is never
1052  * called for functions with only one definition.
1053  */
call_cost(const Context & context,const FunctionDeclaration & function,const ExpressionArray & arguments)1054 static CoercionCost call_cost(const Context& context,
1055                               const FunctionDeclaration& function,
1056                               const ExpressionArray& arguments) {
1057     // Strict-ES2 programs can never call an `$es3` function.
1058     if (context.fConfig->strictES2Mode() && function.modifierFlags().isES3()) {
1059         return CoercionCost::Impossible();
1060     }
1061     // Functions with the wrong number of parameters are never a match.
1062     if (function.parameters().size() != SkToSizeT(arguments.size())) {
1063         return CoercionCost::Impossible();
1064     }
1065     // If the arguments cannot be coerced to the parameter types, the function is never a match.
1066     FunctionDeclaration::ParamTypes types;
1067     const Type* ignored;
1068     if (!function.determineFinalTypes(arguments, &types, &ignored)) {
1069         return CoercionCost::Impossible();
1070     }
1071     // If the arguments do not match the parameter types due to mismatched modifiers, the function
1072     // is never a match.
1073     for (int i = 0; i < arguments.size(); i++) {
1074         const Expression& arg = *arguments[i];
1075         const Variable& param = *function.parameters()[i];
1076         if (!argument_and_parameter_flags_match(arg, param)) {
1077             return CoercionCost::Impossible();
1078         }
1079     }
1080     // Return the sum of coercion costs of each argument.
1081     CoercionCost total = CoercionCost::Free();
1082     for (int i = 0; i < arguments.size(); i++) {
1083         total = total + arguments[i]->coercionCost(*types[i]);
1084     }
1085     return total;
1086 }
1087 
// Chooses the cheapest overload in `overloadChain` for `arguments`.
// A chain with a single declaration is returned as-is, without any cost analysis.
// Returns nullptr when no overload can accept the arguments.
const FunctionDeclaration* FunctionCall::FindBestFunctionForCall(
        const Context& context,
        const FunctionDeclaration* overloadChain,
        const ExpressionArray& arguments) {
    if (overloadChain->nextOverload() == nullptr) {
        return overloadChain;
    }

    const FunctionDeclaration* winner = nullptr;
    CoercionCost winningCost = CoercionCost::Impossible();
    const FunctionDeclaration* candidate = overloadChain;
    while (candidate != nullptr) {
        CoercionCost candidateCost = call_cost(context, *candidate, arguments);
        // `<=` means that among equally-cheap overloads, the one visited last wins.
        if (candidateCost <= winningCost) {
            winningCost = candidateCost;
            winner = candidate;
        }
        candidate = candidate->nextOverload();
    }

    if (winningCost.fImpossible) {
        return nullptr;
    }
    return winner;
}
1106 
build_argument_type_list(SkSpan<const std::unique_ptr<Expression>> arguments)1107 static std::string build_argument_type_list(SkSpan<const std::unique_ptr<Expression>> arguments) {
1108     std::string result = "(";
1109     auto separator = SkSL::String::Separator();
1110     for (const std::unique_ptr<Expression>& arg : arguments) {
1111         result += separator();
1112         result += arg->type().displayName();
1113     }
1114     return result + ")";
1115 }
1116 
Convert(const Context & context,Position pos,std::unique_ptr<Expression> functionValue,ExpressionArray arguments)1117 std::unique_ptr<Expression> FunctionCall::Convert(const Context& context,
1118                                                   Position pos,
1119                                                   std::unique_ptr<Expression> functionValue,
1120                                                   ExpressionArray arguments) {
1121     switch (functionValue->kind()) {
1122         case Expression::Kind::kTypeReference:
1123             return Constructor::Convert(context,
1124                                         pos,
1125                                         functionValue->as<TypeReference>().value(),
1126                                         std::move(arguments));
1127         case Expression::Kind::kFunctionReference: {
1128             const FunctionReference& ref = functionValue->as<FunctionReference>();
1129             const FunctionDeclaration* best = FindBestFunctionForCall(context, ref.overloadChain(),
1130                                                                       arguments);
1131             if (best) {
1132                 return FunctionCall::Convert(context, pos, *best, std::move(arguments));
1133             }
1134             std::string msg = "no match for " + std::string(ref.overloadChain()->name()) +
1135                               build_argument_type_list(arguments);
1136             context.fErrors->error(pos, msg);
1137             return nullptr;
1138         }
1139         case Expression::Kind::kMethodReference: {
1140             MethodReference& ref = functionValue->as<MethodReference>();
1141             arguments.push_back(std::move(ref.self()));
1142 
1143             const FunctionDeclaration* best = FindBestFunctionForCall(context, ref.overloadChain(),
1144                                                                       arguments);
1145             if (best) {
1146                 return FunctionCall::Convert(context, pos, *best, std::move(arguments));
1147             }
1148             std::string msg =
1149                     "no match for " + arguments.back()->type().displayName() +
1150                     "::" + std::string(ref.overloadChain()->name().substr(1)) +
1151                     build_argument_type_list(SkSpan(arguments).first(arguments.size() - 1));
1152             context.fErrors->error(pos, msg);
1153             return nullptr;
1154         }
1155         case Expression::Kind::kPoison:
1156             functionValue->fPosition = pos;
1157             return functionValue;
1158         default:
1159             context.fErrors->error(pos, "not a function");
1160             return nullptr;
1161     }
1162 }
1163 
Convert(const Context & context,Position pos,const FunctionDeclaration & function,ExpressionArray arguments)1164 std::unique_ptr<Expression> FunctionCall::Convert(const Context& context,
1165                                                   Position pos,
1166                                                   const FunctionDeclaration& function,
1167                                                   ExpressionArray arguments) {
1168     // Reject ES3 function calls in strict ES2 mode.
1169     if (context.fConfig->strictES2Mode() && function.modifierFlags().isES3()) {
1170         context.fErrors->error(pos, "call to '" + function.description() + "' is not supported");
1171         return nullptr;
1172     }
1173 
1174     // Reject function calls with the wrong number of arguments.
1175     if (function.parameters().size() != SkToSizeT(arguments.size())) {
1176         std::string msg = "call to '" + std::string(function.name()) + "' expected " +
1177                           std::to_string(function.parameters().size()) + " argument";
1178         if (function.parameters().size() != 1) {
1179             msg += "s";
1180         }
1181         msg += ", but found " + std::to_string(arguments.size());
1182         context.fErrors->error(pos, msg);
1183         return nullptr;
1184     }
1185 
1186     // If the arguments do not match the parameter types due to mismatched modifiers, reject the
1187     // function call.
1188     for (int i = 0; i < arguments.size(); i++) {
1189         const Expression& arg = *arguments[i];
1190         const Variable& param = *function.parameters()[i];
1191         if (!argument_and_parameter_flags_match(arg, param)) {
1192             context.fErrors->error(arg.position(), "expected argument of type '" +
1193                                                    param.layout().paddedDescription() +
1194                                                    param.modifierFlags().paddedDescription() +
1195                                                    param.type().description() + "'");
1196             return nullptr;
1197         }
1198     }
1199 
1200     // Resolve generic types.
1201     FunctionDeclaration::ParamTypes types;
1202     const Type* returnType;
1203     if (!function.determineFinalTypes(arguments, &types, &returnType)) {
1204         std::string msg = "no match for " + std::string(function.name()) +
1205                           build_argument_type_list(arguments);
1206         context.fErrors->error(pos, msg);
1207         return nullptr;
1208     }
1209 
1210     for (int i = 0; i < arguments.size(); i++) {
1211         // Coerce each argument to the proper type.
1212         arguments[i] = types[i]->coerceExpression(std::move(arguments[i]), context);
1213         if (!arguments[i]) {
1214             return nullptr;
1215         }
1216         // Update the refKind on out-parameters, and ensure that they are actually assignable.
1217         ModifierFlags paramFlags = function.parameters()[i]->modifierFlags();
1218         if (paramFlags & ModifierFlag::kOut) {
1219             const VariableRefKind refKind = (paramFlags & ModifierFlag::kIn)
1220                                                     ? VariableReference::RefKind::kReadWrite
1221                                                     : VariableReference::RefKind::kPointer;
1222             if (!Analysis::UpdateVariableRefKind(arguments[i].get(), refKind, context.fErrors)) {
1223                 return nullptr;
1224             }
1225         }
1226     }
1227 
1228     if (function.isMain()) {
1229         context.fErrors->error(pos, "call to 'main' is not allowed");
1230         return nullptr;
1231     }
1232 
1233     if (function.intrinsicKind() == k_eval_IntrinsicKind) {
1234         // This is a method call on an effect child. Translate it into a ChildCall, which simplifies
1235         // handling in the generators and analysis code.
1236         const Variable& child = *arguments.back()->as<VariableReference>().variable();
1237         arguments.pop_back();
1238         return ChildCall::Make(context, pos, returnType, child, std::move(arguments));
1239     }
1240 
1241     return Make(context, pos, returnType, function, std::move(arguments));
1242 }
1243 
Make(const Context & context,Position pos,const Type * returnType,const FunctionDeclaration & function,ExpressionArray arguments)1244 std::unique_ptr<Expression> FunctionCall::Make(const Context& context,
1245                                                Position pos,
1246                                                const Type* returnType,
1247                                                const FunctionDeclaration& function,
1248                                                ExpressionArray arguments) {
1249     SkASSERT(function.parameters().size() == SkToSizeT(arguments.size()));
1250 
1251     // We might be able to optimize built-in intrinsics.
1252     if (function.isIntrinsic() && has_compile_time_constant_arguments(arguments)) {
1253         // The function is an intrinsic and all inputs are compile-time constants. Optimize it.
1254         if (std::unique_ptr<Expression> expr = optimize_intrinsic_call(context,
1255                                                                        pos,
1256                                                                        function.intrinsicKind(),
1257                                                                        arguments,
1258                                                                        *returnType)) {
1259             expr->fPosition = pos;
1260             return expr;
1261         }
1262     }
1263 
1264     return std::make_unique<FunctionCall>(pos, returnType, &function, std::move(arguments),
1265                                           /*stablePointer=*/nullptr);
1266 }
1267 
1268 }  // namespace SkSL
1269