xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 constexpr ImmutableString kAtomicCountersVarName   = ImmutableString("atomicCounters");
23 constexpr ImmutableString kAtomicCountersBlockName = ImmutableString("ANGLEAtomicCounters");
24 constexpr ImmutableString kAtomicCounterFieldName  = ImmutableString("counters");
25 
26 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)27 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
28 {
29     // Define `uint counters[];` as the only field in the interface block.
30     TFieldList *fieldList = new TFieldList;
31     TType *counterType    = new TType(EbtUInt, EbpHigh, EvqGlobal);
32     counterType->makeArray(0);
33 
34     TField *countersField =
35         new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
36 
37     fieldList->push_back(countersField);
38 
39     TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
40     coherentMemory.coherent         = true;
41 
42     // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
43     // in libANGLE/Constants.h.
44     constexpr uint32_t kMaxAtomicCounterBuffers = 8;
45 
46     // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
47     TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
48     layoutQualifier.blockStorage     = EbsStd430;
49 
50     return DeclareInterfaceBlock(root, symbolTable, fieldList, EvqBuffer, layoutQualifier,
51                                  coherentMemory, kMaxAtomicCounterBuffers, kAtomicCountersBlockName,
52                                  kAtomicCountersVarName);
53 }
54 
// Builds the expression that extracts the buffer offset of atomic counter buffer |binding| from
// |uniformBufferOffsets|.
//
// Each uint in the offsets uniform packs the offsets of 4 consecutive bindings, 8 bits each, so
// the generated expression is:
//
//     uniformBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
//
// The shift is omitted when it would be by zero bits.  |uniformBufferOffsets| itself is not
// inserted in the result; a deep copy of it is used.
TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding)
{
    constexpr int kBindingsPerUint = 4;
    constexpr int kBitsPerBinding  = 8;

    const int packedIndex = binding / kBindingsPerUint;
    const int shiftAmount = (binding % kBindingsPerUint) * kBitsPerBinding;

    // uniformBufferOffsets[binding / 4]
    TIntermTyped *offsetsCopy = uniformBufferOffsets->deepCopy();
    TIntermBinary *packedOffsets =
        new TIntermBinary(EOpIndexDirect, offsetsCopy, CreateIndexNode(packedIndex));

    // uniformBufferOffsets[binding / 4] >> ((binding % 4) * 8)
    if (shiftAmount != 0)
    {
        packedOffsets =
            new TIntermBinary(EOpBitShiftRight, packedOffsets, CreateUIntNode(shiftAmount));
    }

    // uniformBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
    return new TIntermBinary(EOpBitwiseAnd, packedOffsets, CreateUIntNode(0xFF));
}
77 
CreateAtomicCounterRef(TIntermTyped * atomicCounterExpression,const TVariable * atomicCounters,const TIntermTyped * uniformBufferOffsets)78 TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression,
79                                       const TVariable *atomicCounters,
80                                       const TIntermTyped *uniformBufferOffsets)
81 {
82     // The atomic counters storage buffer declaration looks as such:
83     //
84     // layout(...) buffer ANGLEAtomicCounters
85     // {
86     //     uint counters[];
87     // } atomicCounters[N];
88     //
89     // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
90     //
91     // This function takes an expression that uses an atomic counter, which can either be:
92     //
93     //  - ac
94     //  - acArray[index]
95     //
96     // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic
97     // counters.
98     //
99     // For the first case (ac), the following code is generated:
100     //
101     //     atomicCounters[binding].counters[offset]
102     //
103     // For the second case (acArray[index]), the following code is generated:
104     //
105     //     atomicCounters[binding].counters[offset + index]
106     //
107     // In either case, an offset given through uniforms is also added to |offset|.  The binding is
108     // necessarily a constant thanks to MonomorphizeUnsupportedFunctions.
109 
110     // First determine if there's an index, and extract the atomic counter symbol out of the
111     // expression.
112     TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode();
113     TIntermTyped *atomicCounterIndex   = nullptr;
114     int atomicCounterConstIndex        = 0;
115     TIntermBinary *asBinary            = atomicCounterExpression->getAsBinaryNode();
116     if (asBinary != nullptr)
117     {
118         atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode();
119 
120         switch (asBinary->getOp())
121         {
122             case EOpIndexDirect:
123                 atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
124                 break;
125             case EOpIndexIndirect:
126                 atomicCounterIndex = asBinary->getRight();
127                 break;
128             default:
129                 UNREACHABLE();
130         }
131     }
132 
133     // Extract binding and offset information out of the atomic counter symbol.
134     ASSERT(atomicCounterSymbol);
135     const TVariable *atomicCounterVar = &atomicCounterSymbol->variable();
136     const TType &atomicCounterType    = atomicCounterVar->getType();
137 
138     const int binding = atomicCounterType.getLayoutQualifier().binding;
139     int offset        = atomicCounterType.getLayoutQualifier().offset / 4;
140 
141     // Create the expression:
142     //
143     //     offset + arrayIndex + uniformOffset
144     //
145     // If arrayIndex is a constant, it's added with offset right here.
146 
147     offset += atomicCounterConstIndex;
148 
149     TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding);
150     if (atomicCounterIndex != nullptr)
151     {
152         index = new TIntermBinary(EOpAdd, index, atomicCounterIndex);
153     }
154     if (offset != 0)
155     {
156         index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset));
157     }
158 
159     // Finally, create the complete expression:
160     //
161     //     atomicCounters[binding].counters[index]
162 
163     TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
164 
165     // atomicCounters[binding]
166     TIntermBinary *countersBlock =
167         new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding));
168 
169     // atomicCounters[binding].counters
170     TIntermBinary *counters =
171         new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0));
172 
173     return new TIntermBinary(EOpIndexIndirect, counters, index);
174 }
175 
176 // Traverser that:
177 //
178 // 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset.
179 // 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
180 class RewriteAtomicCountersTraverser : public TIntermTraverser
181 {
182   public:
RewriteAtomicCountersTraverser(TSymbolTable * symbolTable,const TVariable * atomicCounters,const TIntermTyped * acbBufferOffsets)183     RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
184                                    const TVariable *atomicCounters,
185                                    const TIntermTyped *acbBufferOffsets)
186         : TIntermTraverser(true, false, false, symbolTable),
187           mAtomicCounters(atomicCounters),
188           mAcbBufferOffsets(acbBufferOffsets)
189     {}
190 
visitDeclaration(Visit visit,TIntermDeclaration * node)191     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
192     {
193         if (!mInGlobalScope)
194         {
195             return true;
196         }
197 
198         const TIntermSequence &sequence = *(node->getSequence());
199 
200         TIntermTyped *variable = sequence.front()->getAsTyped();
201         const TType &type      = variable->getType();
202         bool isAtomicCounter   = type.isAtomicCounter();
203 
204         if (isAtomicCounter)
205         {
206             ASSERT(type.getQualifier() == EvqUniform);
207             TIntermSequence emptySequence;
208             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
209                                             std::move(emptySequence));
210 
211             return false;
212         }
213 
214         return true;
215     }
216 
visitAggregate(Visit visit,TIntermAggregate * node)217     bool visitAggregate(Visit visit, TIntermAggregate *node) override
218     {
219         if (BuiltInGroup::IsBuiltIn(node->getOp()))
220         {
221             bool converted = convertBuiltinFunction(node);
222             return !converted;
223         }
224 
225         // AST functions don't require modification as atomic counter function parameters are
226         // removed by MonomorphizeUnsupportedFunctions.
227         return true;
228     }
229 
visitSymbol(TIntermSymbol * symbol)230     void visitSymbol(TIntermSymbol *symbol) override
231     {
232         // Cannot encounter the atomic counter symbol directly.  It can only be used with functions,
233         // and therefore it's handled by visitAggregate.
234         ASSERT(!symbol->getType().isAtomicCounter());
235     }
236 
visitBinary(Visit visit,TIntermBinary * node)237     bool visitBinary(Visit visit, TIntermBinary *node) override
238     {
239         // Cannot encounter an atomic counter expression directly.  It can only be used with
240         // functions, and therefore it's handled by visitAggregate.
241         ASSERT(!node->getType().isAtomicCounter());
242         return true;
243     }
244 
245   private:
convertBuiltinFunction(TIntermAggregate * node)246     bool convertBuiltinFunction(TIntermAggregate *node)
247     {
248         const TOperator op = node->getOp();
249 
250         // If the function is |memoryBarrierAtomicCounter|, simply replace it with
251         // |memoryBarrierBuffer|.
252         if (op == EOpMemoryBarrierAtomicCounter)
253         {
254             TIntermSequence emptySequence;
255             TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
256                 "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310);
257             queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
258             return true;
259         }
260 
261         // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
262         if (!node->getFunction()->isAtomicCounterFunction())
263         {
264             return false;
265         }
266 
267         // Note: atomicAdd(0) is used for atomic reads.
268         uint32_t valueChange                = 0;
269         constexpr char kAtomicAddFunction[] = "atomicAdd";
270         bool isDecrement                    = false;
271 
272         if (op == EOpAtomicCounterIncrement)
273         {
274             valueChange = 1;
275         }
276         else if (op == EOpAtomicCounterDecrement)
277         {
278             // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
279             valueChange = std::numeric_limits<uint32_t>::max();
280             static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
281                           "uint32_t max is not -1");
282 
283             isDecrement = true;
284         }
285         else
286         {
287             ASSERT(op == EOpAtomicCounter);
288         }
289 
290         TIntermTyped *param = (*node->getSequence())[0]->getAsTyped();
291 
292         TIntermSequence substituteArguments;
293         substituteArguments.push_back(
294             CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets));
295         substituteArguments.push_back(CreateUIntNode(valueChange));
296 
297         TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
298             kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310);
299 
300         // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
301         // unlike atomicAdd.  So we need to do a -1 on the result as well.
302         if (isDecrement)
303         {
304             substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1));
305         }
306 
307         queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
308         return true;
309     }
310 
311     const TVariable *mAtomicCounters;
312     const TIntermTyped *mAcbBufferOffsets;
313 };
314 
315 }  // anonymous namespace
316 
RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets,const TVariable ** atomicCountersOut)317 bool RewriteAtomicCounters(TCompiler *compiler,
318                            TIntermBlock *root,
319                            TSymbolTable *symbolTable,
320                            const TIntermTyped *acbBufferOffsets,
321                            const TVariable **atomicCountersOut)
322 {
323     const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
324     if (atomicCountersOut)
325     {
326         *atomicCountersOut = atomicCounters;
327     }
328 
329     RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
330     root->traverse(&traverser);
331     return traverser.updateTree(compiler, root);
332 }
333 }  // namespace sh
334