1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/msl/RewriteOutArgs.h"
8 #include "compiler/translator/IntermRebuild.h"
9
10 using namespace sh;
11
12 namespace
13 {
14
// A minimal multiset kept as a flat vector of (element, count) records.
// Lookups walk the distinct elements linearly and compare with operator==.
template <typename T>
class SmallMultiSet
{
  public:
    // A distinct element paired with the number of times it has been inserted.
    struct Entry
    {
        T elem;
        size_t count;
    };

    // Returns the record whose element equals `x`, or nullptr when no such record exists.
    const Entry *find(const T &x) const
    {
        for (size_t index = 0; index < mEntries.size(); ++index)
        {
            if (x == mEntries[index].elem)
            {
                return &mEntries[index];
            }
        }
        return nullptr;
    }

    // Returns how many times `x` has been inserted since the last clear(); 0 if never.
    size_t multiplicity(const T &x) const
    {
        if (const Entry *hit = find(x))
        {
            return hit->count;
        }
        return 0;
    }

    // Records one more occurrence of `x` and returns the record that now describes it.
    // An existing record has its count bumped; otherwise a new record with count 1 is appended.
    const Entry &insert(const T &x)
    {
        for (Entry &candidate : mEntries)
        {
            if (x == candidate.elem)
            {
                ++candidate.count;
                return candidate;
            }
        }

        mEntries.push_back({x, 1});
        return mEntries.back();
    }

    // Drops every record.
    void clear() { mEntries.clear(); }

    // True when nothing has been inserted since construction or the last clear().
    bool empty() const { return mEntries.empty(); }

    // Number of distinct elements currently recorded (duplicates are not counted).
    size_t uniqueSize() const { return mEntries.size(); }

  private:
    std::vector<Entry> mEntries;
};
70
GetVariable(TIntermNode & node)71 const TVariable *GetVariable(TIntermNode &node)
72 {
73 TIntermTyped *tyNode = node.getAsTyped();
74 ASSERT(tyNode);
75 if (TIntermSymbol *symbol = tyNode->getAsSymbolNode())
76 {
77 return &symbol->variable();
78 }
79 return nullptr;
80 }
81
82 class Rewriter : public TIntermRebuild
83 {
84 SmallMultiSet<const TVariable *> mVarBuffer; // reusable buffer
85 SymbolEnv &mSymbolEnv;
86
87 public:
~Rewriter()88 ~Rewriter() override { ASSERT(mVarBuffer.empty()); }
89
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)90 Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
91 : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
92 {}
93
argAlreadyProcessed(TIntermTyped * arg)94 static bool argAlreadyProcessed(TIntermTyped *arg)
95 {
96 if (arg->getAsAggregate())
97 {
98 const TFunction *func = arg->getAsAggregate()->getFunction();
99 // These two builtins already generate references, and the
100 // ANGLE_inout and ANGLE_out overloads in ProgramPrelude are both
101 // unnecessary and incompatible.
102 if (func && func->symbolType() == SymbolType::AngleInternal &&
103 (func->name() == "swizzle_ref" || func->name() == "elem_ref"))
104 {
105 return true;
106 }
107 }
108 return false;
109 }
110
visitAggregatePost(TIntermAggregate & aggregateNode)111 PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
112 {
113 ASSERT(mVarBuffer.empty());
114
115 const TFunction *func = aggregateNode.getFunction();
116 if (!func)
117 {
118 return aggregateNode;
119 }
120
121 TIntermSequence &args = *aggregateNode.getSequence();
122 size_t argCount = args.size();
123
124 auto getParamQualifier = [&](size_t i) {
125 const TVariable ¶m = *func->getParam(i);
126 const TType ¶mType = param.getType();
127 const TQualifier paramQual = paramType.getQualifier();
128 return paramQual;
129 };
130
131 // Check which params might be aliased, and mark all out params as references.
132 bool mightAlias = false;
133 for (size_t i = 0; i < argCount; ++i)
134 {
135 const TQualifier paramQual = getParamQualifier(i);
136
137 switch (paramQual)
138 {
139 case TQualifier::EvqParamOut:
140 case TQualifier::EvqParamInOut:
141 {
142 const TVariable ¶m = *func->getParam(i);
143 if (!mSymbolEnv.isReference(param))
144 {
145 mSymbolEnv.markAsReference(param, AddressSpace::Thread);
146 }
147 // Note: not the same as param above, this refers to the variable in the
148 // argument list in the callsite.
149 const TVariable *var = GetVariable(*args[i]);
150 if (mVarBuffer.insert(var).count > 1)
151 {
152 mightAlias = true;
153 }
154 }
155 break;
156
157 default:
158 {
159 // If a function directly accesses global or stage output variables, the
160 // relevant internal struct is passed in as a parameter during translation.
161 // Ensure it is included in the aliasing checks.
162 const TVariable *var = GetVariable(*args[i]);
163 if (var != nullptr && var->symbolType() == SymbolType::AngleInternal)
164 {
165 // These names are set in Pipeline::getStructInstanceName
166 const ImmutableString &name = var->name();
167 if (name == "vertexOut" || name == "fragmentOut" ||
168 name == "nonConstGlobals")
169 {
170 mVarBuffer.insert(var);
171 }
172 }
173 }
174 break;
175 }
176 }
177
178 // Non-symbol (e.g., TIntermBinary) parameters are cached as null pointers.
179 const bool hasIndeterminateVar = mVarBuffer.find(nullptr);
180
181 if (!mightAlias)
182 {
183 // Support aliasing when there is only one unresolved parameter
184 // and at least one resolved parameter. This may happen in the
185 // following cases:
186 //
187 // - A struct member (or an array element) is passed along with the struct
188 //
189 // struct S { float f; };
190 // void foo(out S a, out float b) {...}
191 // void bar() {
192 // S s;
193 // foo(s, s.f);
194 // }
195 //
196 // mVarBuffer: s and nullptr (for s.f)
197 //
198 // - A global (or built-in) variable is passed as an out/inout
199 // parameter and also used in the called function directly
200 //
201 // float x;
202 // bool foo(out float a) {
203 // a = 2.0;
204 // return x == 1.0 && a == 2.0;
205 // }
206 // void bar() {
207 // x = 1.0;
208 // foo(x); // == true
209 // }
210 //
211 // In this case, foo and x will be translated to
212 //
213 // struct ANGLE_NonConstGlobals { float _ux; };
214 // bool _ufoo(thread ANGLE_NonConstGlobals & ANGLE_nonConstGlobals,
215 // thread float & _ua)
216 //
217 // mVarBuffer: nonConstGlobals and nullptr (for nonConstGlobals._ux)
218 mightAlias = hasIndeterminateVar && mVarBuffer.uniqueSize() > 1;
219 }
220
221 if (mightAlias)
222 {
223 for (size_t i = 0; i < argCount; ++i)
224 {
225 TIntermTyped *arg = args[i]->getAsTyped();
226 ASSERT(arg);
227 if (!argAlreadyProcessed(arg))
228 {
229 const TVariable *var = GetVariable(*arg);
230 const TQualifier paramQual = getParamQualifier(i);
231
232 if (hasIndeterminateVar || mVarBuffer.multiplicity(var) > 1)
233 {
234 switch (paramQual)
235 {
236 case TQualifier::EvqParamOut:
237 args[i] = &mSymbolEnv.callFunctionOverload(
238 Name("out"), arg->getType(), *new TIntermSequence{arg});
239 break;
240
241 case TQualifier::EvqParamInOut:
242 args[i] = &mSymbolEnv.callFunctionOverload(
243 Name("inout"), arg->getType(), *new TIntermSequence{arg});
244 break;
245
246 default:
247 break;
248 }
249 }
250 }
251 }
252 }
253
254 mVarBuffer.clear();
255
256 return aggregateNode;
257 }
258 };
259
260 } // anonymous namespace
261
RewriteOutArgs(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)262 bool sh::RewriteOutArgs(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
263 {
264 Rewriter rewriter(compiler, symbolEnv);
265 if (!rewriter.rebuildRoot(root))
266 {
267 return false;
268 }
269 return true;
270 }
271