xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/RewriteOutArgs.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/RewriteOutArgs.h"
8 #include "compiler/translator/IntermRebuild.h"
9 
10 using namespace sh;
11 
12 namespace
13 {
14 
// Multiset intended for a handful of elements. Distinct elements are stored in insertion order in
// a vector and located by a linear scan comparing with operator==, so no hashing or ordering of T
// is required.
template <typename T>
class SmallMultiSet
{
  public:
    // A distinct element and how many times it has been inserted.
    struct Entry
    {
        T elem;
        size_t count;
    };

    // Returns the entry whose element equals `x`, or nullptr if `x` was never inserted.
    const Entry *find(const T &x) const
    {
        for (size_t idx = 0; idx < mEntries.size(); ++idx)
        {
            if (x == mEntries[idx].elem)
            {
                return &mEntries[idx];
            }
        }
        return nullptr;
    }

    // How many times `x` has been inserted; 0 when it is absent.
    size_t multiplicity(const T &x) const
    {
        const Entry *hit = find(x);
        if (hit == nullptr)
        {
            return 0;
        }
        return hit->count;
    }

    // Adds one occurrence of `x` and returns its (updated or newly created) entry.
    // The returned reference may be invalidated by a later insert of a new element, because the
    // backing vector can reallocate.
    const Entry &insert(const T &x)
    {
        for (Entry &existing : mEntries)
        {
            if (x == existing.elem)
            {
                ++existing.count;
                return existing;
            }
        }
        mEntries.push_back(Entry{x, 1});
        return mEntries.back();
    }

    // Forgets every element.
    void clear() { mEntries.clear(); }

    // True when no element has been inserted since construction or the last clear().
    bool empty() const { return mEntries.empty(); }

    // Number of distinct elements, ignoring multiplicities.
    size_t uniqueSize() const { return mEntries.size(); }

  private:
    std::vector<Entry> mEntries;
};
70 
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73     TIntermTyped *tyNode = node.getAsTyped();
74     ASSERT(tyNode);
75     if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76     {
77         return &symbol->variable();
78     }
79     return nullptr;
80 }
81 
82 class Rewriter : public TIntermRebuild
83 {
84     SmallMultiSet<const TVariable *> mVarBuffer;  // reusable buffer
85     SymbolEnv &mSymbolEnv;
86 
87   public:
~Rewriter()88     ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89 
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90     Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91         : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92     {}
93 
argAlreadyProcessed(TIntermTyped * arg)94     static bool argAlreadyProcessed(TIntermTyped *arg)
95     {
96         if (arg->getAsAggregate())
97         {
98             const TFunction *func = arg->getAsAggregate()->getFunction();
99             // These two builtins already generate references, and the
100             // ANGLE_inout and ANGLE_out overloads in ProgramPrelude are both
101             // unnecessary and incompatible.
102             if (func && func->symbolType() == SymbolType::AngleInternal &&
103                 (func->name() == "swizzle_ref" || func->name() == "elem_ref"))
104             {
105                 return true;
106             }
107         }
108         return false;
109     }
110 
visitAggregatePost(TIntermAggregate & aggregateNode)111     PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
112     {
113         ASSERT(mVarBuffer.empty());
114 
115         const TFunction *func = aggregateNode.getFunction();
116         if (!func)
117         {
118             return aggregateNode;
119         }
120 
121         TIntermSequence &args = *aggregateNode.getSequence();
122         size_t argCount       = args.size();
123 
124         auto getParamQualifier = [&](size_t i) {
125             const TVariable &param     = *func->getParam(i);
126             const TType &paramType     = param.getType();
127             const TQualifier paramQual = paramType.getQualifier();
128             return paramQual;
129         };
130 
131         // Check which params might be aliased, and mark all out params as references.
132         bool mightAlias = false;
133         for (size_t i = 0; i < argCount; ++i)
134         {
135             const TQualifier paramQual = getParamQualifier(i);
136 
137             switch (paramQual)
138             {
139                 case TQualifier::EvqParamOut:
140                 case TQualifier::EvqParamInOut:
141                 {
142                     const TVariable &param = *func->getParam(i);
143                     if (!mSymbolEnv.isReference(param))
144                     {
145                         mSymbolEnv.markAsReference(param, AddressSpace::Thread);
146                     }
147                     // Note: not the same as param above, this refers to the variable in the
148                     // argument list in the callsite.
149                     const TVariable *var = GetVariable(*args[i]);
150                     if (mVarBuffer.insert(var).count > 1)
151                     {
152                         mightAlias = true;
153                     }
154                 }
155                 break;
156 
157                 default:
158                 {
159                     // If a function directly accesses global or stage output variables, the
160                     // relevant internal struct is passed in as a parameter during translation.
161                     // Ensure it is included in the aliasing checks.
162                     const TVariable *var = GetVariable(*args[i]);
163                     if (var != nullptr && var->symbolType() == SymbolType::AngleInternal)
164                     {
165                         // These names are set in Pipeline::getStructInstanceName
166                         const ImmutableString &name = var->name();
167                         if (name == "vertexOut" || name == "fragmentOut" ||
168                             name == "nonConstGlobals")
169                         {
170                             mVarBuffer.insert(var);
171                         }
172                     }
173                 }
174                 break;
175             }
176         }
177 
178         // Non-symbol (e.g., TIntermBinary) parameters are cached as null pointers.
179         const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
180 
181         if (!mightAlias)
182         {
183             // Support aliasing when there is only one unresolved parameter
184             // and at least one resolved parameter. This may happen in the
185             // following cases:
186             //
187             //   - A struct member (or an array element) is passed along with the struct
188             //
189             //         struct S { float f; };
190             //         void foo(out S a, out float b) {...}
191             //         void bar() {
192             //             S s;
193             //             foo(s, s.f);
194             //         }
195             //
196             //     mVarBuffer: s and nullptr (for s.f)
197             //
198             //   - A global (or built-in) variable is passed as an out/inout
199             //     parameter and also used in the called function directly
200             //
201             //         float x;
202             //         bool foo(out float a) {
203             //             a = 2.0;
204             //             return x == 1.0 && a == 2.0;
205             //         }
206             //         void bar() {
207             //             x = 1.0;
208             //             foo(x); // == true
209             //         }
210             //
211             //     In this case, foo and x will be translated to
212             //
213             //         struct ANGLE_NonConstGlobals { float _ux; };
214             //         bool _ufoo(thread ANGLE_NonConstGlobals & ANGLE_nonConstGlobals,
215             //                    thread float & _ua)
216             //
217             //     mVarBuffer: nonConstGlobals and nullptr (for nonConstGlobals._ux)
218             mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
219         }
220 
221         if (mightAlias)
222         {
223             for (size_t i = 0; i < argCount; ++i)
224             {
225                 TIntermTyped *arg = args[i]->getAsTyped();
226                 ASSERT(arg);
227                 if (!argAlreadyProcessed(arg))
228                 {
229                     const TVariable *var       = GetVariable(*arg);
230                     const TQualifier paramQual = getParamQualifier(i);
231 
232                     if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
233                     {
234                         switch (paramQual)
235                         {
236                             case TQualifier::EvqParamOut:
237                                 args[i] = &mSymbolEnv.callFunctionOverload(
238                                     Name("out"), arg->getType(), *new TIntermSequence{arg});
239                                 break;
240 
241                             case TQualifier::EvqParamInOut:
242                                 args[i] = &mSymbolEnv.callFunctionOverload(
243                                     Name("inout"), arg->getType(), *new TIntermSequence{arg});
244                                 break;
245 
246                             default:
247                                 break;
248                         }
249                     }
250                 }
251             }
252         }
253 
254         mVarBuffer.clear();
255 
256         return aggregateNode;
257     }
258 };
259 
260 }  // anonymous namespace
261 
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)262 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
263 {
264     Rewriter rewriter(compiler, symbolEnv);
265     if (!rewriter.rebuildRoot(root))
266     {
267         return false;
268     }
269     return true;
270 }
271