xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/RewriteUnaddressableReferences.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/RewriteUnaddressableReferences.h"
8 #include "compiler/translator/AsNode.h"
9 #include "compiler/translator/IntermRebuild.h"
10 #include "compiler/translator/msl/AstHelpers.h"
11 
12 using namespace sh;
13 
14 namespace
15 {
16 
IsOutParam(const TType & paramType)17 bool IsOutParam(const TType &paramType)
18 {
19     const TQualifier qual = paramType.getQualifier();
20     switch (qual)
21     {
22         case TQualifier::EvqParamInOut:
23         case TQualifier::EvqParamOut:
24             return true;
25 
26         default:
27             return false;
28     }
29 }
30 
IsVectorAccess(TIntermBinary & binary)31 bool IsVectorAccess(TIntermBinary &binary)
32 {
33     TOperator op = binary.getOp();
34     switch (op)
35     {
36         case TOperator::EOpIndexDirect:
37         case TOperator::EOpIndexIndirect:
38             break;
39 
40         default:
41             return false;
42     }
43 
44     const TType &leftType = binary.getLeft()->getType();
45     if (!leftType.isVector() || leftType.isArray())
46     {
47         return false;
48     }
49 
50     ASSERT(IsScalarBasicType(binary.getType()));
51 
52     return true;
53 }
54 
IsVectorAccess(TIntermNode & node)55 bool IsVectorAccess(TIntermNode &node)
56 {
57     if (auto *bin = node.getAsBinaryNode())
58     {
59         return IsVectorAccess(*bin);
60     }
61     return false;
62 }
63 
64 // Differs from IsAssignment in that it does not include (++) or (--).
IsAssignEqualsSign(TOperator op)65 bool IsAssignEqualsSign(TOperator op)
66 {
67     switch (op)
68     {
69         case TOperator::EOpAssign:
70         case TOperator::EOpInitialize:
71         case TOperator::EOpAddAssign:
72         case TOperator::EOpSubAssign:
73         case TOperator::EOpMulAssign:
74         case TOperator::EOpVectorTimesMatrixAssign:
75         case TOperator::EOpVectorTimesScalarAssign:
76         case TOperator::EOpMatrixTimesScalarAssign:
77         case TOperator::EOpMatrixTimesMatrixAssign:
78         case TOperator::EOpDivAssign:
79         case TOperator::EOpIModAssign:
80         case TOperator::EOpBitShiftLeftAssign:
81         case TOperator::EOpBitShiftRightAssign:
82         case TOperator::EOpBitwiseAndAssign:
83         case TOperator::EOpBitwiseXorAssign:
84         case TOperator::EOpBitwiseOrAssign:
85             return true;
86 
87         default:
88             return false;
89     }
90 }
91 
92 // Only includes ($=) style assigns, where ($) is a binary op.
IsCompoundAssign(TOperator op)93 bool IsCompoundAssign(TOperator op)
94 {
95     switch (op)
96     {
97         case TOperator::EOpAddAssign:
98         case TOperator::EOpSubAssign:
99         case TOperator::EOpMulAssign:
100         case TOperator::EOpVectorTimesMatrixAssign:
101         case TOperator::EOpVectorTimesScalarAssign:
102         case TOperator::EOpMatrixTimesScalarAssign:
103         case TOperator::EOpMatrixTimesMatrixAssign:
104         case TOperator::EOpDivAssign:
105         case TOperator::EOpIModAssign:
106         case TOperator::EOpBitShiftLeftAssign:
107         case TOperator::EOpBitShiftRightAssign:
108         case TOperator::EOpBitwiseAndAssign:
109         case TOperator::EOpBitwiseXorAssign:
110         case TOperator::EOpBitwiseOrAssign:
111             return true;
112 
113         default:
114             return false;
115     }
116 }
117 
ReturnsReference(TOperator op)118 bool ReturnsReference(TOperator op)
119 {
120     switch (op)
121     {
122         case TOperator::EOpAssign:
123         case TOperator::EOpInitialize:
124         case TOperator::EOpAddAssign:
125         case TOperator::EOpSubAssign:
126         case TOperator::EOpMulAssign:
127         case TOperator::EOpVectorTimesMatrixAssign:
128         case TOperator::EOpVectorTimesScalarAssign:
129         case TOperator::EOpMatrixTimesScalarAssign:
130         case TOperator::EOpMatrixTimesMatrixAssign:
131         case TOperator::EOpDivAssign:
132         case TOperator::EOpIModAssign:
133         case TOperator::EOpBitShiftLeftAssign:
134         case TOperator::EOpBitShiftRightAssign:
135         case TOperator::EOpBitwiseAndAssign:
136         case TOperator::EOpBitwiseXorAssign:
137         case TOperator::EOpBitwiseOrAssign:
138 
139         case TOperator::EOpPostIncrement:
140         case TOperator::EOpPostDecrement:
141         case TOperator::EOpPreIncrement:
142         case TOperator::EOpPreDecrement:
143 
144         case TOperator::EOpIndexDirect:
145         case TOperator::EOpIndexIndirect:
146         case TOperator::EOpIndexDirectStruct:
147         case TOperator::EOpIndexDirectInterfaceBlock:
148 
149             return true;
150 
151         default:
152             return false;
153     }
154 }
155 
DecomposeCompoundAssignment(TIntermBinary & node)156 TIntermTyped &DecomposeCompoundAssignment(TIntermBinary &node)
157 {
158     TOperator op = node.getOp();
159     switch (op)
160     {
161         case TOperator::EOpAddAssign:
162             op = TOperator::EOpAdd;
163             break;
164         case TOperator::EOpSubAssign:
165             op = TOperator::EOpSub;
166             break;
167         case TOperator::EOpMulAssign:
168             op = TOperator::EOpMul;
169             break;
170         case TOperator::EOpVectorTimesMatrixAssign:
171             op = TOperator::EOpVectorTimesMatrix;
172             break;
173         case TOperator::EOpVectorTimesScalarAssign:
174             op = TOperator::EOpVectorTimesScalar;
175             break;
176         case TOperator::EOpMatrixTimesScalarAssign:
177             op = TOperator::EOpMatrixTimesScalar;
178             break;
179         case TOperator::EOpMatrixTimesMatrixAssign:
180             op = TOperator::EOpMatrixTimesMatrix;
181             break;
182         case TOperator::EOpDivAssign:
183             op = TOperator::EOpDiv;
184             break;
185         case TOperator::EOpIModAssign:
186             op = TOperator::EOpIMod;
187             break;
188         case TOperator::EOpBitShiftLeftAssign:
189             op = TOperator::EOpBitShiftLeft;
190             break;
191         case TOperator::EOpBitShiftRightAssign:
192             op = TOperator::EOpBitShiftRight;
193             break;
194         case TOperator::EOpBitwiseAndAssign:
195             op = TOperator::EOpBitwiseAnd;
196             break;
197         case TOperator::EOpBitwiseXorAssign:
198             op = TOperator::EOpBitwiseXor;
199             break;
200         case TOperator::EOpBitwiseOrAssign:
201             op = TOperator::EOpBitwiseOr;
202             break;
203         default:
204             UNREACHABLE();
205     }
206 
207     // This assumes SeparateCompoundExpressions has already been called.
208     // This assumption allows this code to not need to introduce temporaries.
209     //
210     // e.g. dont have to worry about:
211     //      vec[hasSideEffect()] *= 4
212     // becoming
213     //      vec[hasSideEffect()] = vec[hasSideEffect()] * 4
214 
215     TIntermTyped *left  = node.getLeft();
216     TIntermTyped *right = node.getRight();
217     return *new TIntermBinary(TOperator::EOpAssign, left->deepCopy(),
218                               new TIntermBinary(op, left, right));
219 }
220 
221 class Rewriter1 : public TIntermRebuild
222 {
223   public:
Rewriter1(TCompiler & compiler)224     Rewriter1(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
225 
visitBinaryPost(TIntermBinary & binaryNode)226     PostResult visitBinaryPost(TIntermBinary &binaryNode) override
227     {
228         const TOperator op = binaryNode.getOp();
229         if (IsCompoundAssign(op))
230         {
231             TIntermTyped &left = *binaryNode.getLeft();
232             if (left.getAsSwizzleNode() || IsVectorAccess(left))
233             {
234                 return DecomposeCompoundAssignment(binaryNode);
235             }
236         }
237         return binaryNode;
238     }
239 };
240 
241 class Rewriter2 : public TIntermRebuild
242 {
243     std::vector<bool> mRequiresAddressingStack;
244     SymbolEnv &mSymbolEnv;
245 
246   private:
requiresAddressing() const247     bool requiresAddressing() const
248     {
249         if (mRequiresAddressingStack.empty())
250         {
251             return false;
252         }
253         return mRequiresAddressingStack.back();
254     }
255 
256   public:
~Rewriter2()257     ~Rewriter2() override { ASSERT(mRequiresAddressingStack.empty()); }
258 
Rewriter2(TCompiler & compiler,SymbolEnv & symbolEnv)259     Rewriter2(TCompiler &compiler, SymbolEnv &symbolEnv)
260         : TIntermRebuild(compiler, true, true), mSymbolEnv(symbolEnv)
261     {}
262 
visitAggregatePre(TIntermAggregate & aggregateNode)263     PreResult visitAggregatePre(TIntermAggregate &aggregateNode) override
264     {
265         const TFunction *func = aggregateNode.getFunction();
266         if (!func)
267         {
268             return aggregateNode;
269         }
270 
271         TIntermSequence &args = *aggregateNode.getSequence();
272         size_t argCount       = args.size();
273 
274         for (size_t i = 0; i < argCount; ++i)
275         {
276             const TVariable &param = *func->getParam(i);
277             const TType &paramType = param.getType();
278             TIntermNode *arg       = args[i];
279             ASSERT(arg);
280 
281             mRequiresAddressingStack.push_back(IsOutParam(paramType));
282             args[i] = rebuild(*arg).single();
283             ASSERT(args[i]);
284             ASSERT(!mRequiresAddressingStack.empty());
285             mRequiresAddressingStack.pop_back();
286         }
287 
288         return {aggregateNode, VisitBits::Neither};
289     }
290 
visitSwizzlePost(TIntermSwizzle & swizzleNode)291     PostResult visitSwizzlePost(TIntermSwizzle &swizzleNode) override
292     {
293         if (!requiresAddressing())
294         {
295             return swizzleNode;
296         }
297 
298         TIntermTyped &vecNode         = *swizzleNode.getOperand();
299         const TQualifierList &offsets = swizzleNode.getSwizzleOffsets();
300         ASSERT(!offsets.empty());
301         ASSERT(offsets.size() <= 4);
302 
303         auto &args = *new TIntermSequence();
304         args.reserve(offsets.size() + 1);
305         args.push_back(&vecNode);
306         for (int offset : offsets)
307         {
308             args.push_back(new TIntermConstantUnion(new TConstantUnion(offset),
309                                                     *new TType(TBasicType::EbtInt)));
310         }
311 
312         return mSymbolEnv.callFunctionOverload(Name("swizzle_ref"), swizzleNode.getType(), args);
313     }
314 
visitBinaryPre(TIntermBinary & binaryNode)315     PreResult visitBinaryPre(TIntermBinary &binaryNode) override
316     {
317         const TOperator op = binaryNode.getOp();
318 
319         const bool isAccess = IsVectorAccess(binaryNode);
320 
321         const bool disableTop   = !ReturnsReference(op) || !requiresAddressing();
322         const bool disableLeft  = disableTop;
323         const bool disableRight = disableTop || isAccess || IsAssignEqualsSign(op);
324 
325         auto traverse = [&](TIntermTyped &node, const bool disable) -> TIntermTyped & {
326             if (disable)
327             {
328                 mRequiresAddressingStack.push_back(false);
329             }
330             auto *newNode = asNode<TIntermTyped>(rebuild(node).single());
331             ASSERT(newNode);
332             if (disable)
333             {
334                 mRequiresAddressingStack.pop_back();
335             }
336             return *newNode;
337         };
338 
339         TIntermTyped &leftNode  = *binaryNode.getLeft();
340         TIntermTyped &rightNode = *binaryNode.getRight();
341 
342         TIntermTyped &newLeft  = traverse(leftNode, disableLeft);
343         TIntermTyped &newRight = traverse(rightNode, disableRight);
344 
345         if (!isAccess || disableTop)
346         {
347             if (&leftNode == &newLeft && &rightNode == &newRight)
348             {
349                 return {&binaryNode, VisitBits::Neither};
350             }
351             return {*new TIntermBinary(op, &newLeft, &newRight), VisitBits::Neither};
352         }
353 
354         return {mSymbolEnv.callFunctionOverload(Name("elem_ref"), binaryNode.getType(),
355                                                 *new TIntermSequence{&newLeft, &newRight}),
356                 VisitBits::Neither};
357     }
358 };
359 
360 }  // anonymous namespace
361 
RewriteUnaddressableReferences(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)362 bool sh::RewriteUnaddressableReferences(TCompiler &compiler,
363                                         TIntermBlock &root,
364                                         SymbolEnv &symbolEnv)
365 {
366     if (!Rewriter1(compiler).rebuildRoot(root))
367     {
368         return false;
369     }
370     if (!Rewriter2(compiler, symbolEnv).rebuildRoot(root))
371     {
372         return false;
373     }
374     return true;
375 }
376