xref: /aosp_15_r20/external/skia/src/sksl/analysis/SkSLSpecialization.h (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2024 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_SPECIALIZATION
9 #define SKSL_SPECIALIZATION
10 
11 #include "include/private/base/SkTArray.h"
12 #include "src/core/SkChecksum.h"
13 #include "src/core/SkTHash.h"
14 #include "src/utils/SkBitSet.h"
15 
16 #include <cstddef>
17 #include <functional>
18 
19 namespace SkSL {
20 
21 class Expression;
22 class FunctionCall;
23 class FunctionDeclaration;
24 class Variable;
25 struct Program;
26 
27 namespace Analysis {
28 
29 /**
30  * Design docs for SkSL function specialization: go/sksl-function-specialization
31  * https://docs.google.com/document/d/1dJdkk2-KmP-62EERzKygzsLLnxihCbDcFi3UHc1WzAM/edit?usp=sharing
32  */
33 
34 // The current index of the specialization function being walked through, used to
35 // track what the proper specialization function call should be if walking through a
36 // specialized function call stack.
37 using SpecializationIndex = int;
38 static constexpr SpecializationIndex kUnspecialized = -1;
39 
40 // Global uniforms used by a specialization,
41 // maps <function parameter, expression referencing global uniform>
42 using SpecializedParameters = skia_private::THashMap<const Variable*, const Expression*>;
43 // The set of specializated implementations needed for a given function.
44 using Specializations = skia_private::TArray<SpecializedParameters>;
45 // The full set of all specializations required by the program.
46 using SpecializationMap = skia_private::THashMap<const FunctionDeclaration*, Specializations>;
47 
48 // This can be used as a key into a map of specialized function declarations. Most backends which
49 // implement function specialization will have a need for this.
50 struct SpecializedFunctionKey {
51     struct Hash {
operatorSpecializedFunctionKey::Hash52         size_t operator()(const SpecializedFunctionKey& entry) {
53             return SkGoodHash()(entry.fDeclaration) ^
54                    SkGoodHash()(entry.fSpecializationIndex);
55         }
56     };
57 
58     bool operator==(const SpecializedFunctionKey& other) const {
59         return fDeclaration == other.fDeclaration &&
60                fSpecializationIndex == other.fSpecializationIndex;
61     }
62 
63     const FunctionDeclaration* fDeclaration = nullptr;
64     SpecializationIndex fSpecializationIndex = Analysis::kUnspecialized;
65 };
66 
67 // This is used as a key into the SpecializedCallMap.
68 struct SpecializedCallKey {
69     struct Hash {
operatorSpecializedCallKey::Hash70         size_t operator()(const SpecializedCallKey& entry) {
71             return SkGoodHash()(entry.fStablePointer) ^
72                    SkGoodHash()(entry.fParentSpecializationIndex);
73         }
74     };
75 
76     bool operator==(const SpecializedCallKey& other) const {
77         return fStablePointer == other.fStablePointer &&
78                fParentSpecializationIndex == other.fParentSpecializationIndex;
79     }
80 
81     const FunctionCall* fStablePointer = nullptr;
82     SpecializationIndex fParentSpecializationIndex = Analysis::kUnspecialized;
83 };
84 
85 // The mapping of function calls and their inherited specialization to their corresponding
86 // specialization index in `Specializations`
87 using SpecializedCallMap = skia_private::THashMap<SpecializedCallKey,
88                                                   SpecializationIndex,
89                                                   SpecializedCallKey::Hash>;
90 struct SpecializationInfo {
91     SpecializationMap fSpecializationMap;
92     SpecializedCallMap fSpecializedCallMap;
93 };
94 
95 // A function that returns true if the parameter variable fits the criteria
96 // to create a specialization.
97 using ParameterMatchesFn = std::function<bool(const Variable&)>;
98 
99 // Finds functions that contain parameters that should be specialized on and writes the
100 // specialization info to the provided `SpecializationInfo`.
101 void FindFunctionsToSpecialize(const Program& program,
102                                SpecializationInfo* info,
103                                const ParameterMatchesFn& specializationFn);
104 
105 // Given a function call and the active specialization index, looks up the specialization index for
106 // the call target. In other words: in the specialization map, we first look up the call target's
107 // declaration, which yields a Specialization array. We would find the correct mappings in the array
108 // at the SpecializationIndex returned by this function.
109 SpecializationIndex FindSpecializationIndexForCall(const FunctionCall& call,
110                                                    const SpecializationInfo& info,
111                                                    SpecializationIndex activeSpecializationIndex);
112 
113 // Given a function, returns a bit-mask corresponding to each parameter. If a bit is set, the
114 // corresponding parameter is specialized and should be excluded from the argument/parameter list.
115 SkBitSet FindSpecializedParametersForFunction(const FunctionDeclaration& func,
116                                               const SpecializationInfo& info);
117 
118 // Given a function and its specialization index, invokes a callback once per specialized parameter.
119 // The callback will be passed the parameter's index, the parameter variable, and the specialized
120 // value at the given specialization index.
121 using ParameterMappingCallback = std::function<void(int paramIndex,
122                                                     const Variable* param,
123                                                     const Expression* value)>;
124 
125 void GetParameterMappingsForFunction(const FunctionDeclaration& func,
126                                      const SpecializationInfo& info,
127                                      SpecializationIndex specializationIndex,
128                                      const ParameterMappingCallback& callback);
129 
130 }  // namespace Analysis
131 }  // namespace SkSL
132 
133 #endif
134