1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
10
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17
18 namespace sh
19 {
20 namespace
21 {
22 constexpr ImmutableString kAtomicCountersVarName = ImmutableString("atomicCounters");
23 constexpr ImmutableString kAtomicCountersBlockName = ImmutableString("ANGLEAtomicCounters");
24 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");
25
26 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)27 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
28 {
29 // Define `uint counters[];` as the only field in the interface block.
30 TFieldList *fieldList = new TFieldList;
31 TType *counterType = new TType(EbtUInt, EbpHigh, EvqGlobal);
32 counterType->makeArray(0);
33
34 TField *countersField =
35 new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
36
37 fieldList->push_back(countersField);
38
39 TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
40 coherentMemory.coherent = true;
41
42 // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
43 // in libANGLE/Constants.h.
44 constexpr uint32_t kMaxAtomicCounterBuffers = 8;
45
46 // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
47 TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
48 layoutQualifier.blockStorage = EbsStd430;
49
50 return DeclareInterfaceBlock(root, symbolTable, fieldList, EvqBuffer, layoutQualifier,
51 coherentMemory, kMaxAtomicCounterBuffers, kAtomicCountersBlockName,
52 kAtomicCountersVarName);
53 }
54
// Builds the expression that reads the per-binding buffer offset out of |uniformBufferOffsets|.
//
// Each uint in the |acbBufferOffsets| uniform contains offsets for 4 bindings, one per byte.
// Therefore, the expression to get the uniform offset for the binding is:
//
//     acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
//
// |uniformBufferOffsets| itself is not inserted in the tree; a deep copy of it is.
TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding)
{
    const int packedWordIndex = binding / 4;
    const int byteInWord      = binding % 4;

    // acbBufferOffsets[binding / 4]
    TIntermBinary *packedWord = new TIntermBinary(
        EOpIndexDirect, uniformBufferOffsets->deepCopy(), CreateIndexNode(packedWordIndex));

    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8)
    //
    // The shift is omitted when the offset already lives in the lowest byte.
    TIntermBinary *offsetInLowByte = packedWord;
    if (byteInWord != 0)
    {
        offsetInLowByte =
            new TIntermBinary(EOpBitShiftRight, packedWord, CreateUIntNode(byteInWord * 8));
    }

    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
    return new TIntermBinary(EOpBitwiseAnd, offsetInLowByte, CreateUIntNode(0xFF));
}
77
CreateAtomicCounterRef(TIntermTyped * atomicCounterExpression,const TVariable * atomicCounters,const TIntermTyped * uniformBufferOffsets)78 TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression,
79 const TVariable *atomicCounters,
80 const TIntermTyped *uniformBufferOffsets)
81 {
82 // The atomic counters storage buffer declaration looks as such:
83 //
84 // layout(...) buffer ANGLEAtomicCounters
85 // {
86 // uint counters[];
87 // } atomicCounters[N];
88 //
89 // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
90 //
91 // This function takes an expression that uses an atomic counter, which can either be:
92 //
93 // - ac
94 // - acArray[index]
95 //
96 // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic
97 // counters.
98 //
99 // For the first case (ac), the following code is generated:
100 //
101 // atomicCounters[binding].counters[offset]
102 //
103 // For the second case (acArray[index]), the following code is generated:
104 //
105 // atomicCounters[binding].counters[offset + index]
106 //
107 // In either case, an offset given through uniforms is also added to |offset|. The binding is
108 // necessarily a constant thanks to MonomorphizeUnsupportedFunctions.
109
110 // First determine if there's an index, and extract the atomic counter symbol out of the
111 // expression.
112 TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode();
113 TIntermTyped *atomicCounterIndex = nullptr;
114 int atomicCounterConstIndex = 0;
115 TIntermBinary *asBinary = atomicCounterExpression->getAsBinaryNode();
116 if (asBinary != nullptr)
117 {
118 atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode();
119
120 switch (asBinary->getOp())
121 {
122 case EOpIndexDirect:
123 atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
124 break;
125 case EOpIndexIndirect:
126 atomicCounterIndex = asBinary->getRight();
127 break;
128 default:
129 UNREACHABLE();
130 }
131 }
132
133 // Extract binding and offset information out of the atomic counter symbol.
134 ASSERT(atomicCounterSymbol);
135 const TVariable *atomicCounterVar = &atomicCounterSymbol->variable();
136 const TType &atomicCounterType = atomicCounterVar->getType();
137
138 const int binding = atomicCounterType.getLayoutQualifier().binding;
139 int offset = atomicCounterType.getLayoutQualifier().offset / 4;
140
141 // Create the expression:
142 //
143 // offset + arrayIndex + uniformOffset
144 //
145 // If arrayIndex is a constant, it's added with offset right here.
146
147 offset += atomicCounterConstIndex;
148
149 TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding);
150 if (atomicCounterIndex != nullptr)
151 {
152 index = new TIntermBinary(EOpAdd, index, atomicCounterIndex);
153 }
154 if (offset != 0)
155 {
156 index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset));
157 }
158
159 // Finally, create the complete expression:
160 //
161 // atomicCounters[binding].counters[index]
162
163 TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
164
165 // atomicCounters[binding]
166 TIntermBinary *countersBlock =
167 new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding));
168
169 // atomicCounters[binding].counters
170 TIntermBinary *counters =
171 new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0));
172
173 return new TIntermBinary(EOpIndexIndirect, counters, index);
174 }
175
176 // Traverser that:
177 //
178 // 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset.
179 // 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
180 class RewriteAtomicCountersTraverser : public TIntermTraverser
181 {
182 public:
RewriteAtomicCountersTraverser(TSymbolTable * symbolTable,const TVariable * atomicCounters,const TIntermTyped * acbBufferOffsets)183 RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
184 const TVariable *atomicCounters,
185 const TIntermTyped *acbBufferOffsets)
186 : TIntermTraverser(true, false, false, symbolTable),
187 mAtomicCounters(atomicCounters),
188 mAcbBufferOffsets(acbBufferOffsets)
189 {}
190
visitDeclaration(Visit visit,TIntermDeclaration * node)191 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
192 {
193 if (!mInGlobalScope)
194 {
195 return true;
196 }
197
198 const TIntermSequence &sequence = *(node->getSequence());
199
200 TIntermTyped *variable = sequence.front()->getAsTyped();
201 const TType &type = variable->getType();
202 bool isAtomicCounter = type.isAtomicCounter();
203
204 if (isAtomicCounter)
205 {
206 ASSERT(type.getQualifier() == EvqUniform);
207 TIntermSequence emptySequence;
208 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
209 std::move(emptySequence));
210
211 return false;
212 }
213
214 return true;
215 }
216
visitAggregate(Visit visit,TIntermAggregate * node)217 bool visitAggregate(Visit visit, TIntermAggregate *node) override
218 {
219 if (BuiltInGroup::IsBuiltIn(node->getOp()))
220 {
221 bool converted = convertBuiltinFunction(node);
222 return !converted;
223 }
224
225 // AST functions don't require modification as atomic counter function parameters are
226 // removed by MonomorphizeUnsupportedFunctions.
227 return true;
228 }
229
visitSymbol(TIntermSymbol * symbol)230 void visitSymbol(TIntermSymbol *symbol) override
231 {
232 // Cannot encounter the atomic counter symbol directly. It can only be used with functions,
233 // and therefore it's handled by visitAggregate.
234 ASSERT(!symbol->getType().isAtomicCounter());
235 }
236
visitBinary(Visit visit,TIntermBinary * node)237 bool visitBinary(Visit visit, TIntermBinary *node) override
238 {
239 // Cannot encounter an atomic counter expression directly. It can only be used with
240 // functions, and therefore it's handled by visitAggregate.
241 ASSERT(!node->getType().isAtomicCounter());
242 return true;
243 }
244
245 private:
convertBuiltinFunction(TIntermAggregate * node)246 bool convertBuiltinFunction(TIntermAggregate *node)
247 {
248 const TOperator op = node->getOp();
249
250 // If the function is |memoryBarrierAtomicCounter|, simply replace it with
251 // |memoryBarrierBuffer|.
252 if (op == EOpMemoryBarrierAtomicCounter)
253 {
254 TIntermSequence emptySequence;
255 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
256 "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310);
257 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
258 return true;
259 }
260
261 // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
262 if (!node->getFunction()->isAtomicCounterFunction())
263 {
264 return false;
265 }
266
267 // Note: atomicAdd(0) is used for atomic reads.
268 uint32_t valueChange = 0;
269 constexpr char kAtomicAddFunction[] = "atomicAdd";
270 bool isDecrement = false;
271
272 if (op == EOpAtomicCounterIncrement)
273 {
274 valueChange = 1;
275 }
276 else if (op == EOpAtomicCounterDecrement)
277 {
278 // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
279 valueChange = std::numeric_limits<uint32_t>::max();
280 static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
281 "uint32_t max is not -1");
282
283 isDecrement = true;
284 }
285 else
286 {
287 ASSERT(op == EOpAtomicCounter);
288 }
289
290 TIntermTyped *param = (*node->getSequence())[0]->getAsTyped();
291
292 TIntermSequence substituteArguments;
293 substituteArguments.push_back(
294 CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets));
295 substituteArguments.push_back(CreateUIntNode(valueChange));
296
297 TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
298 kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310);
299
300 // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
301 // unlike atomicAdd. So we need to do a -1 on the result as well.
302 if (isDecrement)
303 {
304 substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1));
305 }
306
307 queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
308 return true;
309 }
310
311 const TVariable *mAtomicCounters;
312 const TIntermTyped *mAcbBufferOffsets;
313 };
314
315 } // anonymous namespace
316
RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets,const TVariable ** atomicCountersOut)317 bool RewriteAtomicCounters(TCompiler *compiler,
318 TIntermBlock *root,
319 TSymbolTable *symbolTable,
320 const TIntermTyped *acbBufferOffsets,
321 const TVariable **atomicCountersOut)
322 {
323 const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
324 if (atomicCountersOut)
325 {
326 *atomicCountersOut = atomicCounters;
327 }
328
329 RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
330 root->traverse(&traverser);
331 return traverser.updateTree(compiler, root);
332 }
333 } // namespace sh
334