xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/RewriteInterpolants.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2023 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/RewriteInterpolants.h"
8 
9 #include "compiler/translator/StaticType.h"
10 #include "compiler/translator/msl/AstHelpers.h"
11 #include "compiler/translator/tree_util/BuiltIn.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13 #include "compiler/translator/tree_util/ReplaceVariable.h"
14 
15 namespace sh
16 {
17 
18 namespace
19 {
20 
21 class FindInterpolantsTraverser : public TIntermTraverser
22 {
23   public:
FindInterpolantsTraverser(TSymbolTable * symbolTable,const DriverUniformMetal * driverUniforms)24     FindInterpolantsTraverser(TSymbolTable *symbolTable, const DriverUniformMetal *driverUniforms)
25         : TIntermTraverser(true, false, false, symbolTable),
26           mDriverUniforms(driverUniforms),
27           mUsesSampleInterpolation(false)
28     {}
29 
visitDeclaration(Visit,TIntermDeclaration * node)30     bool visitDeclaration(Visit, TIntermDeclaration *node) override
31     {
32         const TIntermSequence &sequence = *(node->getSequence());
33         ASSERT(!sequence.empty());
34 
35         const TIntermTyped &typedNode = *(sequence.front()->getAsTyped());
36         TQualifier qualifier          = typedNode.getQualifier();
37         if (qualifier == EvqSampleIn || qualifier == EvqNoPerspectiveSampleIn)
38         {
39             mUsesSampleInterpolation = true;
40         }
41 
42         return true;
43     }
44 
getFlipFunction()45     const TFunction *getFlipFunction()
46     {
47         if (mFlipFunction != nullptr)
48         {
49             return mFlipFunction->getFunction();
50         }
51 
52         const TType *vec2Type  = StaticType::GetQualified<EbtFloat, EbpHigh, EvqParamIn, 2>();
53         TVariable *offsetParam = new TVariable(mSymbolTable, ImmutableString("offset"), vec2Type,
54                                                SymbolType::AngleInternal);
55         TFunction *function =
56             new TFunction(mSymbolTable, ImmutableString("ANGLEFlipInterpolationOffset"),
57                           SymbolType::AngleInternal, vec2Type, true);
58         function->addParameter(offsetParam);
59 
60         TIntermTyped *flipXY =
61             mDriverUniforms->getFlipXY(mSymbolTable, DriverUniformFlip::Fragment);
62         TIntermTyped *flipped = new TIntermBinary(EOpMul, new TIntermSymbol(offsetParam), flipXY);
63         TIntermBranch *returnStatement = new TIntermBranch(EOpReturn, flipped);
64 
65         TIntermBlock *body = new TIntermBlock;
66         body->appendStatement(returnStatement);
67 
68         mFlipFunction = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
69         return function;
70     }
71 
visitAggregate(Visit visit,TIntermAggregate * node)72     bool visitAggregate(Visit visit, TIntermAggregate *node) override
73     {
74         if (!BuiltInGroup::IsInterpolationFS(node->getOp()))
75         {
76             return true;
77         }
78 
79         TIntermNode *operand = node->getSequence()->at(0);
80         ASSERT(operand);
81 
82         // For all of the interpolation functions, <interpolant> must be an input
83         // variable or an element of an input variable declared as an array.
84         const TIntermSymbol *symbolNode = operand->getAsSymbolNode();
85         if (!symbolNode)
86         {
87             const TIntermBinary *binaryNode = operand->getAsBinaryNode();
88             if (binaryNode &&
89                 (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect))
90             {
91                 symbolNode = binaryNode->getLeft()->getAsSymbolNode();
92             }
93         }
94         ASSERT(symbolNode);
95 
96         // If <interpolant> is declared with a "flat" qualifier, the interpolated
97         // value will have the same value everywhere for a single primitive, so
98         // the location used for the interpolation has no effect and the functions
99         // just return that same value.
100         const TVariable *variable = &symbolNode->variable();
101         if (variable->getType().getQualifier() != EvqFlatIn)
102         {
103             mInterpolants.insert(variable);
104         }
105 
106         // Flip offset's Y if needed.
107         if (node->getOp() == EOpInterpolateAtOffset)
108         {
109             TIntermTyped *offsetNode      = node->getSequence()->at(1)->getAsTyped();
110             TIntermTyped *correctedOffset = TIntermAggregate::CreateFunctionCall(
111                 *getFlipFunction(), new TIntermSequence{offsetNode});
112 
113             queueReplacementWithParent(node, offsetNode, correctedOffset, OriginalNode::IS_DROPPED);
114         }
115 
116         return true;
117     }
118 
usesSampleInterpolation() const119     bool usesSampleInterpolation() const { return mUsesSampleInterpolation; }
120 
getInterpolants() const121     const std::unordered_set<const TVariable *> &getInterpolants() const { return mInterpolants; }
122 
getFlipFunctionDefinition()123     TIntermFunctionDefinition *getFlipFunctionDefinition() { return mFlipFunction; }
124 
125   private:
126     const DriverUniformMetal *mDriverUniforms;
127 
128     bool mUsesSampleInterpolation;
129     std::unordered_set<const TVariable *> mInterpolants;
130     TIntermFunctionDefinition *mFlipFunction = nullptr;
131 };
132 
133 class WrapInterpolantsTraverser : public TIntermTraverser
134 {
135   public:
WrapInterpolantsTraverser(TSymbolTable * symbolTable)136     WrapInterpolantsTraverser(TSymbolTable *symbolTable)
137         : TIntermTraverser(true, false, false, symbolTable), mUsesSampleInterpolant(false)
138     {}
139 
visitSymbol(TIntermSymbol * node)140     void visitSymbol(TIntermSymbol *node) override
141     {
142         // Skip all symbols not previously marked as
143         // interpolants by FindInterpolantsTraverser
144         const TType &type = node->variable().getType();
145         if (!type.isInterpolant())
146         {
147             return;
148         }
149 
150         TIntermNode *ancestor = getAncestorNode(0);
151         ASSERT(ancestor);
152 
153         // Only root-level input varying declarations should be
154         // reachable by this line and they must not be wrapped.
155         if (ancestor->getAsDeclarationNode())
156         {
157             return;
158         }
159 
160         auto checkSkip = [](TIntermNode *node, TIntermNode *parentNode) {
161             if (TIntermAggregate *callNode = parentNode->getAsAggregate())
162             {
163                 if (BuiltInGroup::IsInterpolationFS(callNode->getOp()) &&
164                     callNode->getSequence()->at(0) == node)
165                 {
166                     return true;
167                 }
168             }
169             return false;
170         };
171 
172         // Skip symbols used as the first operand of interpolation functions
173         if (checkSkip(node, ancestor))
174         {
175             return;
176         }
177 
178         TIntermNode *original = node;
179         if (TIntermBinary *binaryNode = ancestor->getAsBinaryNode())
180         {
181             if (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect)
182             {
183                 ancestor = getAncestorNode(1);
184                 ASSERT(ancestor);
185 
186                 // Skip array elements used as the first operand of interpolation functions
187                 if (checkSkip(binaryNode, ancestor))
188                 {
189                     return;
190                 }
191                 original = binaryNode;
192             }
193         }
194 
195         const char *functionName   = nullptr;
196         TIntermSequence *arguments = new TIntermSequence{original};
197         switch (type.getQualifier())
198         {
199             case EvqFragmentIn:
200             case EvqSmoothIn:
201             case EvqNoPerspectiveIn:
202                 // `metal::interpolant` variables cannot be used directly,
203                 // so MSL has a dedicated interpolation function to obtain
204                 // their pixel-center values. This function is included in
205                 // the `MetalFragmentSample` built-in functions group.
206                 functionName = "interpolateAtCenter";
207                 break;
208             case EvqCentroidIn:
209             case EvqNoPerspectiveCentroidIn:
210                 functionName = "interpolateAtCentroid";
211                 break;
212             case EvqSampleIn:
213             case EvqNoPerspectiveSampleIn:
214                 functionName = "interpolateAtSample";
215                 arguments->push_back(new TIntermSymbol(BuiltInVariable::gl_SampleID()));
216                 mUsesSampleInterpolant = true;
217                 break;
218             default:
219                 UNREACHABLE();
220                 break;
221         }
222         TIntermTyped *replacement = CreateBuiltInFunctionCallNode(
223             functionName, arguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
224 
225         queueReplacementWithParent(ancestor, original, replacement, OriginalNode::BECOMES_CHILD);
226     }
227 
usesSampleInterpolant() const228     bool usesSampleInterpolant() const { return mUsesSampleInterpolant; }
229 
230   private:
231     bool mUsesSampleInterpolant;
232 };
233 
234 }  // anonymous namespace
235 
RewriteInterpolants(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable,const DriverUniformMetal * driverUniforms,bool * outUsesSampleInterpolation,bool * outUsesSampleInterpolant)236 [[nodiscard]] bool RewriteInterpolants(TCompiler &compiler,
237                                        TIntermBlock &root,
238                                        TSymbolTable &symbolTable,
239                                        const DriverUniformMetal *driverUniforms,
240                                        bool *outUsesSampleInterpolation,
241                                        bool *outUsesSampleInterpolant)
242 {
243     // Find all fragment inputs used with interpolation functions.
244     FindInterpolantsTraverser findInterpolantsTraverser(&symbolTable, driverUniforms);
245     root.traverse(&findInterpolantsTraverser);
246 
247     // Define ANGLEFlipInterpolationOffset if interpolateAtOffset was used.
248     if (findInterpolantsTraverser.getFlipFunctionDefinition() != nullptr)
249     {
250         const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
251         root.insertStatement(firstFunctionIndex,
252                              findInterpolantsTraverser.getFlipFunctionDefinition());
253     }
254 
255     if (!findInterpolantsTraverser.updateTree(&compiler, &root))
256     {
257         return false;
258     }
259     *outUsesSampleInterpolation = findInterpolantsTraverser.usesSampleInterpolation();
260 
261     // Skip further operations when interpolation functions are not used.
262     if (findInterpolantsTraverser.getInterpolants().empty())
263     {
264         return true;
265     }
266 
267     // Adjust variable types as per MSL requirements
268     //
269     // * Inputs with omitted and smooth interpolation qualifiers will be written as
270     //       metal::interpolant<T, metal::interpolation::perspective>
271     //
272     // * Inputs with noperspective interpolation qualifiers will be written as
273     //       metal::interpolant<T, metal::interpolation::no_perspective>
274     for (const TVariable *var : findInterpolantsTraverser.getInterpolants())
275     {
276         TType *replacementType = new TType(var->getType());
277         replacementType->setInterpolant(true);
278         TVariable *replacement =
279             new TVariable(&symbolTable, var->name(), replacementType, var->symbolType());
280         if (!ReplaceVariable(&compiler, &root, var, replacement))
281         {
282             return false;
283         }
284     }
285 
286     // Wrap direct usages of interpolants with explicit interpolation
287     // functions depending on their auxiliary qualifiers
288     //            in vec4 interpolant -> ANGLE_interpolateAtCenter(interpolant)
289     //   centroid in vec4 interpolant -> ANGLE_interpolateAtCentroid(interpolant)
290     //     sample in vec4 interpolant -> ANGLE_interpolateAtSample(interpolant, gl_SampleID)
291     WrapInterpolantsTraverser wrapInterpolantsTraverser(&symbolTable);
292     root.traverse(&wrapInterpolantsTraverser);
293     if (!wrapInterpolantsTraverser.updateTree(&compiler, &root))
294     {
295         return false;
296     }
297     *outUsesSampleInterpolant = wrapInterpolantsTraverser.usesSampleInterpolant();
298 
299     return true;
300 }
301 
302 }  // namespace sh
303