1 //
2 // Copyright 2023 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/msl/RewriteInterpolants.h"
8
9 #include "compiler/translator/StaticType.h"
10 #include "compiler/translator/msl/AstHelpers.h"
11 #include "compiler/translator/tree_util/BuiltIn.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13 #include "compiler/translator/tree_util/ReplaceVariable.h"
14
15 namespace sh
16 {
17
18 namespace
19 {
20
21 class FindInterpolantsTraverser : public TIntermTraverser
22 {
23 public:
FindInterpolantsTraverser(TSymbolTable * symbolTable,const DriverUniformMetal * driverUniforms)24 FindInterpolantsTraverser(TSymbolTable *symbolTable, const DriverUniformMetal *driverUniforms)
25 : TIntermTraverser(true, false, false, symbolTable),
26 mDriverUniforms(driverUniforms),
27 mUsesSampleInterpolation(false)
28 {}
29
visitDeclaration(Visit,TIntermDeclaration * node)30 bool visitDeclaration(Visit, TIntermDeclaration *node) override
31 {
32 const TIntermSequence &sequence = *(node->getSequence());
33 ASSERT(!sequence.empty());
34
35 const TIntermTyped &typedNode = *(sequence.front()->getAsTyped());
36 TQualifier qualifier = typedNode.getQualifier();
37 if (qualifier == EvqSampleIn || qualifier == EvqNoPerspectiveSampleIn)
38 {
39 mUsesSampleInterpolation = true;
40 }
41
42 return true;
43 }
44
getFlipFunction()45 const TFunction *getFlipFunction()
46 {
47 if (mFlipFunction != nullptr)
48 {
49 return mFlipFunction->getFunction();
50 }
51
52 const TType *vec2Type = StaticType::GetQualified<EbtFloat, EbpHigh, EvqParamIn, 2>();
53 TVariable *offsetParam = new TVariable(mSymbolTable, ImmutableString("offset"), vec2Type,
54 SymbolType::AngleInternal);
55 TFunction *function =
56 new TFunction(mSymbolTable, ImmutableString("ANGLEFlipInterpolationOffset"),
57 SymbolType::AngleInternal, vec2Type, true);
58 function->addParameter(offsetParam);
59
60 TIntermTyped *flipXY =
61 mDriverUniforms->getFlipXY(mSymbolTable, DriverUniformFlip::Fragment);
62 TIntermTyped *flipped = new TIntermBinary(EOpMul, new TIntermSymbol(offsetParam), flipXY);
63 TIntermBranch *returnStatement = new TIntermBranch(EOpReturn, flipped);
64
65 TIntermBlock *body = new TIntermBlock;
66 body->appendStatement(returnStatement);
67
68 mFlipFunction = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
69 return function;
70 }
71
visitAggregate(Visit visit,TIntermAggregate * node)72 bool visitAggregate(Visit visit, TIntermAggregate *node) override
73 {
74 if (!BuiltInGroup::IsInterpolationFS(node->getOp()))
75 {
76 return true;
77 }
78
79 TIntermNode *operand = node->getSequence()->at(0);
80 ASSERT(operand);
81
82 // For all of the interpolation functions, <interpolant> must be an input
83 // variable or an element of an input variable declared as an array.
84 const TIntermSymbol *symbolNode = operand->getAsSymbolNode();
85 if (!symbolNode)
86 {
87 const TIntermBinary *binaryNode = operand->getAsBinaryNode();
88 if (binaryNode &&
89 (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect))
90 {
91 symbolNode = binaryNode->getLeft()->getAsSymbolNode();
92 }
93 }
94 ASSERT(symbolNode);
95
96 // If <interpolant> is declared with a "flat" qualifier, the interpolated
97 // value will have the same value everywhere for a single primitive, so
98 // the location used for the interpolation has no effect and the functions
99 // just return that same value.
100 const TVariable *variable = &symbolNode->variable();
101 if (variable->getType().getQualifier() != EvqFlatIn)
102 {
103 mInterpolants.insert(variable);
104 }
105
106 // Flip offset's Y if needed.
107 if (node->getOp() == EOpInterpolateAtOffset)
108 {
109 TIntermTyped *offsetNode = node->getSequence()->at(1)->getAsTyped();
110 TIntermTyped *correctedOffset = TIntermAggregate::CreateFunctionCall(
111 *getFlipFunction(), new TIntermSequence{offsetNode});
112
113 queueReplacementWithParent(node, offsetNode, correctedOffset, OriginalNode::IS_DROPPED);
114 }
115
116 return true;
117 }
118
usesSampleInterpolation() const119 bool usesSampleInterpolation() const { return mUsesSampleInterpolation; }
120
getInterpolants() const121 const std::unordered_set<const TVariable *> &getInterpolants() const { return mInterpolants; }
122
getFlipFunctionDefinition()123 TIntermFunctionDefinition *getFlipFunctionDefinition() { return mFlipFunction; }
124
125 private:
126 const DriverUniformMetal *mDriverUniforms;
127
128 bool mUsesSampleInterpolation;
129 std::unordered_set<const TVariable *> mInterpolants;
130 TIntermFunctionDefinition *mFlipFunction = nullptr;
131 };
132
133 class WrapInterpolantsTraverser : public TIntermTraverser
134 {
135 public:
WrapInterpolantsTraverser(TSymbolTable * symbolTable)136 WrapInterpolantsTraverser(TSymbolTable *symbolTable)
137 : TIntermTraverser(true, false, false, symbolTable), mUsesSampleInterpolant(false)
138 {}
139
visitSymbol(TIntermSymbol * node)140 void visitSymbol(TIntermSymbol *node) override
141 {
142 // Skip all symbols not previously marked as
143 // interpolants by FindInterpolantsTraverser
144 const TType &type = node->variable().getType();
145 if (!type.isInterpolant())
146 {
147 return;
148 }
149
150 TIntermNode *ancestor = getAncestorNode(0);
151 ASSERT(ancestor);
152
153 // Only root-level input varying declarations should be
154 // reachable by this line and they must not be wrapped.
155 if (ancestor->getAsDeclarationNode())
156 {
157 return;
158 }
159
160 auto checkSkip = [](TIntermNode *node, TIntermNode *parentNode) {
161 if (TIntermAggregate *callNode = parentNode->getAsAggregate())
162 {
163 if (BuiltInGroup::IsInterpolationFS(callNode->getOp()) &&
164 callNode->getSequence()->at(0) == node)
165 {
166 return true;
167 }
168 }
169 return false;
170 };
171
172 // Skip symbols used as the first operand of interpolation functions
173 if (checkSkip(node, ancestor))
174 {
175 return;
176 }
177
178 TIntermNode *original = node;
179 if (TIntermBinary *binaryNode = ancestor->getAsBinaryNode())
180 {
181 if (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect)
182 {
183 ancestor = getAncestorNode(1);
184 ASSERT(ancestor);
185
186 // Skip array elements used as the first operand of interpolation functions
187 if (checkSkip(binaryNode, ancestor))
188 {
189 return;
190 }
191 original = binaryNode;
192 }
193 }
194
195 const char *functionName = nullptr;
196 TIntermSequence *arguments = new TIntermSequence{original};
197 switch (type.getQualifier())
198 {
199 case EvqFragmentIn:
200 case EvqSmoothIn:
201 case EvqNoPerspectiveIn:
202 // `metal::interpolant` variables cannot be used directly,
203 // so MSL has a dedicated interpolation function to obtain
204 // their pixel-center values. This function is included in
205 // the `MetalFragmentSample` built-in functions group.
206 functionName = "interpolateAtCenter";
207 break;
208 case EvqCentroidIn:
209 case EvqNoPerspectiveCentroidIn:
210 functionName = "interpolateAtCentroid";
211 break;
212 case EvqSampleIn:
213 case EvqNoPerspectiveSampleIn:
214 functionName = "interpolateAtSample";
215 arguments->push_back(new TIntermSymbol(BuiltInVariable::gl_SampleID()));
216 mUsesSampleInterpolant = true;
217 break;
218 default:
219 UNREACHABLE();
220 break;
221 }
222 TIntermTyped *replacement = CreateBuiltInFunctionCallNode(
223 functionName, arguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
224
225 queueReplacementWithParent(ancestor, original, replacement, OriginalNode::BECOMES_CHILD);
226 }
227
usesSampleInterpolant() const228 bool usesSampleInterpolant() const { return mUsesSampleInterpolant; }
229
230 private:
231 bool mUsesSampleInterpolant;
232 };
233
234 } // anonymous namespace
235
RewriteInterpolants(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable,const DriverUniformMetal * driverUniforms,bool * outUsesSampleInterpolation,bool * outUsesSampleInterpolant)236 [[nodiscard]] bool RewriteInterpolants(TCompiler &compiler,
237 TIntermBlock &root,
238 TSymbolTable &symbolTable,
239 const DriverUniformMetal *driverUniforms,
240 bool *outUsesSampleInterpolation,
241 bool *outUsesSampleInterpolant)
242 {
243 // Find all fragment inputs used with interpolation functions.
244 FindInterpolantsTraverser findInterpolantsTraverser(&symbolTable, driverUniforms);
245 root.traverse(&findInterpolantsTraverser);
246
247 // Define ANGLEFlipInterpolationOffset if interpolateAtOffset was used.
248 if (findInterpolantsTraverser.getFlipFunctionDefinition() != nullptr)
249 {
250 const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
251 root.insertStatement(firstFunctionIndex,
252 findInterpolantsTraverser.getFlipFunctionDefinition());
253 }
254
255 if (!findInterpolantsTraverser.updateTree(&compiler, &root))
256 {
257 return false;
258 }
259 *outUsesSampleInterpolation = findInterpolantsTraverser.usesSampleInterpolation();
260
261 // Skip further operations when interpolation functions are not used.
262 if (findInterpolantsTraverser.getInterpolants().empty())
263 {
264 return true;
265 }
266
267 // Adjust variable types as per MSL requirements
268 //
269 // * Inputs with omitted and smooth interpolation qualifiers will be written as
270 // metal::interpolant<T, metal::interpolation::perspective>
271 //
272 // * Inputs with noperspective interpolation qualifiers will be written as
273 // metal::interpolant<T, metal::interpolation::no_perspective>
274 for (const TVariable *var : findInterpolantsTraverser.getInterpolants())
275 {
276 TType *replacementType = new TType(var->getType());
277 replacementType->setInterpolant(true);
278 TVariable *replacement =
279 new TVariable(&symbolTable, var->name(), replacementType, var->symbolType());
280 if (!ReplaceVariable(&compiler, &root, var, replacement))
281 {
282 return false;
283 }
284 }
285
286 // Wrap direct usages of interpolants with explicit interpolation
287 // functions depending on their auxiliary qualifiers
288 // in vec4 interpolant -> ANGLE_interpolateAtCenter(interpolant)
289 // centroid in vec4 interpolant -> ANGLE_interpolateAtCentroid(interpolant)
290 // sample in vec4 interpolant -> ANGLE_interpolateAtSample(interpolant, gl_SampleID)
291 WrapInterpolantsTraverser wrapInterpolantsTraverser(&symbolTable);
292 root.traverse(&wrapInterpolantsTraverser);
293 if (!wrapInterpolantsTraverser.updateTree(&compiler, &root))
294 {
295 return false;
296 }
297 *outUsesSampleInterpolant = wrapInterpolantsTraverser.usesSampleInterpolant();
298
299 return true;
300 }
301
302 } // namespace sh
303