xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/RewriteStructSamplers.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteStructSamplers: Extract samplers from structs.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
10 
11 #include "common/hash_containers.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 
17 namespace sh
18 {
19 namespace
20 {
21 
22 // Used to map one structure type to another (one where the samplers are removed).
23 struct StructureData
24 {
25     // The structure this was replaced with.  If nullptr, it means the structure is removed (because
26     // it had all samplers).
27     //
28     // ParseContext reorders the samplers to the end of the struct, so the EOpIndexDirectStruct
29     // expressions that select non-sampler members don't have to change when they are moved out of
30     // the struct.
31     const TStructure *modified;
32 };
33 
// Maps a structure containing samplers to the data describing its sampler-free replacement.
using StructureMap        = angle::HashMap<const TStructure *, StructureData>;
// Maps a uniform variable of structure type to the variable it is replaced with.
using StructureUniformMap = angle::HashMap<const TVariable *, const TVariable *>;
// Maps a constructed sampler name to the variable of the extracted sampler.
using ExtractedSamplerMap = angle::HashMap<std::string, const TVariable *>;
37 
// Forward declaration.  The definition appears later in this file; it is declared here because
// RewriteExpressionVisitBinaryHelper calls it before that point.
TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
    TCompiler *compiler,
    TIntermBinary *node,
    const StructureMap &structureMap,
    const StructureUniformMap &structureUniformMap,
    const ExtractedSamplerMap &extractedSamplers);
44 
RewriteExpressionVisitBinaryHelper(TCompiler * compiler,TIntermBinary * node,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)45 TIntermTyped *RewriteExpressionVisitBinaryHelper(TCompiler *compiler,
46                                                  TIntermBinary *node,
47                                                  const StructureMap &structureMap,
48                                                  const StructureUniformMap &structureUniformMap,
49                                                  const ExtractedSamplerMap &extractedSamplers)
50 {
51     // Only interested in EOpIndexDirectStruct binary nodes.
52     if (node->getOp() != EOpIndexDirectStruct)
53     {
54         return nullptr;
55     }
56 
57     const TStructure *structure = node->getLeft()->getType().getStruct();
58     ASSERT(structure);
59 
60     // If the result of the index is not a sampler and the struct is not replaced, there's nothing
61     // to do.
62     if (!node->getType().isSampler() && structureMap.find(structure) == structureMap.end())
63     {
64         return nullptr;
65     }
66 
67     // Otherwise, replace the whole expression such that:
68     //
69     // - if sampler, it's indexed with whatever indices the parent structs were indexed with,
70     // - otherwise, the chain of field selections is rewritten by modifying the base uniform so all
71     //   the intermediate nodes would have the correct type (and therefore fields).
72     ASSERT(structureMap.find(structure) != structureMap.end());
73 
74     return RewriteModifiedStructFieldSelectionExpression(compiler, node, structureMap,
75                                                          structureUniformMap, extractedSamplers);
76 }
77 
78 // Given an expression, this traverser calculates a new expression where sampler-in-structs are
79 // replaced with their extracted ones, and field indices are adjusted for the rest of the fields.
80 // In particular, this is run on the right node of EOpIndexIndirect binary nodes, so that the
81 // expression in the index gets a chance to go through this transformation.
82 class RewriteExpressionTraverser final : public TIntermTraverser
83 {
84   public:
RewriteExpressionTraverser(TCompiler * compiler,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)85     explicit RewriteExpressionTraverser(TCompiler *compiler,
86                                         const StructureMap &structureMap,
87                                         const StructureUniformMap &structureUniformMap,
88                                         const ExtractedSamplerMap &extractedSamplers)
89         : TIntermTraverser(true, false, false),
90           mCompiler(compiler),
91           mStructureMap(structureMap),
92           mStructureUniformMap(structureUniformMap),
93           mExtractedSamplers(extractedSamplers)
94     {}
95 
visitBinary(Visit visit,TIntermBinary * node)96     bool visitBinary(Visit visit, TIntermBinary *node) override
97     {
98         TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
99             mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
100 
101         if (rewritten == nullptr)
102         {
103             return true;
104         }
105 
106         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
107 
108         // Don't iterate as the expression is rewritten.
109         return false;
110     }
111 
visitSymbol(TIntermSymbol * node)112     void visitSymbol(TIntermSymbol *node) override
113     {
114         // It's impossible to reach here with a symbol that needs replacement.
115         // MonomorphizeUnsupportedFunctions makes sure that whole structs containing
116         // samplers are not passed to functions, so any instance of the struct uniform is
117         // necessarily indexed right away.  visitBinary should have already taken care of it.
118         ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
119     }
120 
121   private:
122     TCompiler *mCompiler;
123 
124     // See RewriteStructSamplersTraverser.
125     const StructureMap &mStructureMap;
126     const StructureUniformMap &mStructureUniformMap;
127     const ExtractedSamplerMap &mExtractedSamplers;
128 };
129 
130 // Rewrite the index of an EOpIndexIndirect expression.  The root can never need replacing, because
131 // it cannot be a sampler itself or of a struct type.
RewriteIndexExpression(TCompiler * compiler,TIntermTyped * expression,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)132 void RewriteIndexExpression(TCompiler *compiler,
133                             TIntermTyped *expression,
134                             const StructureMap &structureMap,
135                             const StructureUniformMap &structureUniformMap,
136                             const ExtractedSamplerMap &extractedSamplers)
137 {
138     RewriteExpressionTraverser traverser(compiler, structureMap, structureUniformMap,
139                                          extractedSamplers);
140     expression->traverse(&traverser);
141     bool valid = traverser.updateTree(compiler, expression);
142     ASSERT(valid);
143 }
144 
145 // Given an expression such as the following:
146 //
147 //                                                    EOpIndexDirectStruct (sampler)
148 //                                                    /                  \
149 //                                               EOpIndex*           field index
150 //                                              /        \
151 //                                EOpIndexDirectStruct   index 2
152 //                                /                  \
153 //                           EOpIndex*           field index
154 //                          /        \
155 //            EOpIndexDirectStruct   index 1
156 //            /                  \
157 //     Uniform Struct           field index
158 //
159 // produces:
160 //
161 //                                EOpIndex*
162 //                                /      \
163 //                           EOpIndex*  index 2
164 //                          /        \
165 //                      sampler    index 1
166 //
167 // If the expression is not a sampler, it only replaces the struct with the modified one, while
168 // still processing the EOpIndexIndirect expressions (which may contain more structs to map).
RewriteModifiedStructFieldSelectionExpression(TCompiler * compiler,TIntermBinary * node,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)169 TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
170     TCompiler *compiler,
171     TIntermBinary *node,
172     const StructureMap &structureMap,
173     const StructureUniformMap &structureUniformMap,
174     const ExtractedSamplerMap &extractedSamplers)
175 {
176     ASSERT(node->getOp() == EOpIndexDirectStruct);
177 
178     const bool isSampler = node->getType().isSampler();
179 
180     TIntermSymbol *baseUniform = nullptr;
181     std::string samplerName;
182 
183     TVector<TIntermBinary *> indexNodeStack;
184 
185     // Iterate once and build the name of the sampler.
186     TIntermBinary *iter = node;
187     while (baseUniform == nullptr)
188     {
189         indexNodeStack.push_back(iter);
190         baseUniform = iter->getLeft()->getAsSymbolNode();
191 
192         if (isSampler)
193         {
194             if (iter->getOp() == EOpIndexDirectStruct)
195             {
196                 // When indexed into a struct, get the field name instead and construct the sampler
197                 // name.
198                 samplerName.insert(0, iter->getIndexStructFieldName().data());
199                 samplerName.insert(0, "_");
200             }
201 
202             if (baseUniform)
203             {
204                 // If left is a symbol, we have reached the end of the chain.  Use the struct name
205                 // to finish building the name of the sampler.
206                 samplerName.insert(0, baseUniform->variable().name().data());
207             }
208         }
209 
210         iter = iter->getLeft()->getAsBinaryNode();
211     }
212 
213     TIntermTyped *rewritten = nullptr;
214 
215     if (isSampler)
216     {
217         ASSERT(extractedSamplers.find(samplerName) != extractedSamplers.end());
218         rewritten = new TIntermSymbol(extractedSamplers.at(samplerName));
219     }
220     else
221     {
222         const TVariable *baseUniformVar = &baseUniform->variable();
223         ASSERT(structureUniformMap.find(baseUniformVar) != structureUniformMap.end());
224         rewritten = new TIntermSymbol(structureUniformMap.at(baseUniformVar));
225     }
226 
227     // Iterate again and build the expression from bottom up.
228     for (auto it = indexNodeStack.rbegin(); it != indexNodeStack.rend(); ++it)
229     {
230         TIntermBinary *indexNode = *it;
231 
232         switch (indexNode->getOp())
233         {
234             case EOpIndexDirectStruct:
235                 if (!isSampler)
236                 {
237                     rewritten =
238                         new TIntermBinary(EOpIndexDirectStruct, rewritten, indexNode->getRight());
239                 }
240                 break;
241 
242             case EOpIndexDirect:
243                 rewritten = new TIntermBinary(EOpIndexDirect, rewritten, indexNode->getRight());
244                 break;
245 
246             case EOpIndexIndirect:
247             {
248                 // Run RewriteExpressionTraverser on the right node.  It may itself be an expression
249                 // with a sampler inside that needs to be rewritten, or simply use a field of a
250                 // struct that's remapped.
251                 TIntermTyped *indexExpression = indexNode->getRight();
252                 RewriteIndexExpression(compiler, indexExpression, structureMap, structureUniformMap,
253                                        extractedSamplers);
254                 rewritten = new TIntermBinary(EOpIndexIndirect, rewritten, indexExpression);
255                 break;
256             }
257 
258             default:
259                 UNREACHABLE();
260                 break;
261         }
262     }
263 
264     return rewritten;
265 }
266 
267 class RewriteStructSamplersTraverser final : public TIntermTraverser
268 {
269   public:
RewriteStructSamplersTraverser(TCompiler * compiler,TSymbolTable * symbolTable)270     explicit RewriteStructSamplersTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
271         : TIntermTraverser(true, false, false, symbolTable),
272           mCompiler(compiler),
273           mRemovedUniformsCount(0)
274     {}
275 
removedUniformsCount() const276     int removedUniformsCount() const { return mRemovedUniformsCount; }
277 
278     // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
279     // stripped struct sampler.
visitDeclaration(Visit visit,TIntermDeclaration * decl)280     bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
281     {
282         if (!mInGlobalScope)
283         {
284             return true;
285         }
286 
287         const TIntermSequence &sequence = *(decl->getSequence());
288         TIntermTyped *declarator        = sequence.front()->getAsTyped();
289         const TType &type               = declarator->getType();
290 
291         if (!type.isStructureContainingSamplers())
292         {
293             return false;
294         }
295 
296         TIntermSequence newSequence;
297 
298         if (type.isStructSpecifier())
299         {
300             // If this is just a struct definition (not a uniform variable declaration of a
301             // struct type), just remove the samplers.  They are not instantiated yet.
302             const TStructure *structure = type.getStruct();
303             ASSERT(structure && mStructureMap.find(structure) == mStructureMap.end());
304 
305             stripStructSpecifierSamplers(structure, &newSequence);
306         }
307         else
308         {
309             const TStructure *structure = type.getStruct();
310 
311             // If the structure is defined at the same time, create the mapping to the stripped
312             // version first.
313             if (mStructureMap.find(structure) == mStructureMap.end())
314             {
315                 stripStructSpecifierSamplers(structure, &newSequence);
316             }
317 
318             // Then, extract the samplers from the struct and create global-scope variables instead.
319             TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
320             ASSERT(asSymbol);
321             const TVariable &variable = asSymbol->variable();
322             ASSERT(variable.symbolType() != SymbolType::Empty);
323 
324             extractStructSamplerUniforms(variable, structure, &newSequence);
325         }
326 
327         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl,
328                                         std::move(newSequence));
329 
330         return false;
331     }
332 
333     // Same implementation as in RewriteExpressionTraverser.  That traverser cannot replace root.
visitBinary(Visit visit,TIntermBinary * node)334     bool visitBinary(Visit visit, TIntermBinary *node) override
335     {
336         TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
337             mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
338 
339         if (rewritten == nullptr)
340         {
341             return true;
342         }
343 
344         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
345 
346         // Don't iterate as the expression is rewritten.
347         return false;
348     }
349 
350     // Same implementation as in RewriteExpressionTraverser.  That traverser cannot replace root.
visitSymbol(TIntermSymbol * node)351     void visitSymbol(TIntermSymbol *node) override
352     {
353         ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
354     }
355 
356   private:
357     // Removes all samplers from a struct specifier.
stripStructSpecifierSamplers(const TStructure * structure,TIntermSequence * newSequence)358     void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
359     {
360         TFieldList *newFieldList = new TFieldList;
361         ASSERT(structure->containsSamplers());
362 
363         // Add this struct to the struct map
364         ASSERT(mStructureMap.find(structure) == mStructureMap.end());
365         StructureData *modifiedData = &mStructureMap[structure];
366 
367         modifiedData->modified = nullptr;
368 
369         for (size_t fieldIndex = 0; fieldIndex < structure->fields().size(); ++fieldIndex)
370         {
371             const TField *field    = structure->fields()[fieldIndex];
372             const TType &fieldType = *field->type();
373 
374             // If the field is a sampler, or a struct that's entirely removed, skip it.
375             if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
376             {
377                 TType *newType = nullptr;
378 
379                 // Otherwise, if it's a struct that's replaced, create a new field of the replaced
380                 // type.
381                 if (fieldType.isStructureContainingSamplers())
382                 {
383                     const TStructure *fieldStruct = fieldType.getStruct();
384                     ASSERT(mStructureMap.find(fieldStruct) != mStructureMap.end());
385 
386                     const TStructure *modifiedStruct = mStructureMap[fieldStruct].modified;
387                     ASSERT(modifiedStruct);
388 
389                     newType = new TType(modifiedStruct, true);
390                     if (fieldType.isArray())
391                     {
392                         newType->makeArrays(fieldType.getArraySizes());
393                     }
394                 }
395                 else
396                 {
397                     // If not, duplicate the field as is.
398                     newType = new TType(fieldType);
399                 }
400 
401                 TField *newField =
402                     new TField(newType, field->name(), field->line(), field->symbolType());
403                 newFieldList->push_back(newField);
404             }
405         }
406 
407         // Prune empty structs.
408         if (newFieldList->empty())
409         {
410             return;
411         }
412 
413         // Declare a new struct with the same name and the new fields.
414         modifiedData->modified =
415             new TStructure(mSymbolTable,
416                            structure->symbolType() == SymbolType::Empty ? kEmptyImmutableString
417                                                                         : structure->name(),
418                            newFieldList, structure->symbolType());
419         TType *newStructType = new TType(modifiedData->modified, true);
420         TVariable *newStructVar =
421             new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
422         TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
423 
424         TIntermDeclaration *structDecl = new TIntermDeclaration;
425         structDecl->appendDeclarator(newStructRef);
426 
427         newSequence->push_back(structDecl);
428     }
429 
430     // Returns true if the type is a struct that was removed because we extracted all the members.
isRemovedStructType(const TType & type) const431     bool isRemovedStructType(const TType &type) const
432     {
433         const TStructure *structure = type.getStruct();
434         if (structure == nullptr)
435         {
436             // Not a struct
437             return false;
438         }
439 
440         // A struct is removed if it is in the map, but doesn't have a replacement struct.
441         auto iter = mStructureMap.find(structure);
442         return iter != mStructureMap.end() && iter->second.modified == nullptr;
443     }
444 
445     // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
446     // defined sampler uniform.
extractStructSamplerUniforms(const TVariable & variable,const TStructure * structure,TIntermSequence * newSequence)447     void extractStructSamplerUniforms(const TVariable &variable,
448                                       const TStructure *structure,
449                                       TIntermSequence *newSequence)
450     {
451         ASSERT(structure->containsSamplers());
452         ASSERT(mStructureMap.find(structure) != mStructureMap.end());
453 
454         const TType &type = variable.getType();
455         enterArray(type);
456 
457         for (const TField *field : structure->fields())
458         {
459             extractFieldSamplers(variable.name().data(), field, newSequence);
460         }
461 
462         // If there's a replacement structure (because there are non-sampler fields in the struct),
463         // add a declaration with that type.
464         const TStructure *modified = mStructureMap[structure].modified;
465         if (modified != nullptr)
466         {
467             TType *newType = new TType(modified, false);
468             if (type.isArray())
469             {
470                 newType->makeArrays(type.getArraySizes());
471             }
472             newType->setQualifier(EvqUniform);
473             const TVariable *newVariable =
474                 new TVariable(mSymbolTable, variable.name(), newType, variable.symbolType());
475 
476             TIntermDeclaration *newDecl = new TIntermDeclaration();
477             newDecl->appendDeclarator(new TIntermSymbol(newVariable));
478 
479             newSequence->push_back(newDecl);
480 
481             ASSERT(mStructureUniformMap.find(&variable) == mStructureUniformMap.end());
482             mStructureUniformMap[&variable] = newVariable;
483         }
484         else
485         {
486             mRemovedUniformsCount++;
487         }
488 
489         exitArray(type);
490     }
491 
492     // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplers(const std::string & prefix,const TField * field,TIntermSequence * newSequence)493     void extractFieldSamplers(const std::string &prefix,
494                               const TField *field,
495                               TIntermSequence *newSequence)
496     {
497         const TType &fieldType = *field->type();
498         if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
499         {
500             std::string newPrefix = prefix + "_" + field->name().data();
501 
502             if (fieldType.isSampler())
503             {
504                 extractSampler(newPrefix, fieldType, newSequence);
505             }
506             else
507             {
508                 enterArray(fieldType);
509                 const TStructure *structure = fieldType.getStruct();
510                 for (const TField *nestedField : structure->fields())
511                 {
512                     extractFieldSamplers(newPrefix, nestedField, newSequence);
513                 }
514                 exitArray(fieldType);
515             }
516         }
517     }
518 
GenerateArraySizesFromStack(TVector<unsigned int> * sizesOut)519     void GenerateArraySizesFromStack(TVector<unsigned int> *sizesOut)
520     {
521         sizesOut->reserve(mArraySizeStack.size());
522 
523         for (auto it = mArraySizeStack.rbegin(); it != mArraySizeStack.rend(); ++it)
524         {
525             sizesOut->push_back(*it);
526         }
527     }
528 
529     // Extracts a sampler from a struct. Declares the new extracted sampler.
extractSampler(const std::string & newName,const TType & fieldType,TIntermSequence * newSequence)530     void extractSampler(const std::string &newName,
531                         const TType &fieldType,
532                         TIntermSequence *newSequence)
533     {
534         ASSERT(fieldType.isSampler());
535 
536         TType *newType = new TType(fieldType);
537 
538         // Add array dimensions accumulated so far due to struct arrays.  Note that to support
539         // nested arrays, mArraySizeStack has the outermost size in the front.  |makeArrays| thus
540         // expects this in reverse order.
541         TVector<unsigned int> parentArraySizes;
542         GenerateArraySizesFromStack(&parentArraySizes);
543         newType->makeArrays(parentArraySizes);
544 
545         ImmutableStringBuilder nameBuilder(newName.size() + 1);
546         nameBuilder << newName;
547 
548         newType->setQualifier(EvqUniform);
549         TVariable *newVariable =
550             new TVariable(mSymbolTable, nameBuilder, newType, SymbolType::AngleInternal);
551         TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);
552 
553         TIntermDeclaration *samplerDecl = new TIntermDeclaration;
554         samplerDecl->appendDeclarator(newSymbol);
555 
556         newSequence->push_back(samplerDecl);
557 
558         // TODO: Use a temp name instead of generating a name as currently done.  There is no
559         // guarantee that these generated names cannot clash.  Create a mapping from the previous
560         // name to the name assigned to the temp variable so ShaderVariable::mappedName can be
561         // updated post-transformation.  http://anglebug.com/42262930
562         ASSERT(mExtractedSamplers.find(newName) == mExtractedSamplers.end());
563         mExtractedSamplers[newName] = newVariable;
564     }
565 
enterArray(const TType & arrayType)566     void enterArray(const TType &arrayType)
567     {
568         const TSpan<const unsigned int> &arraySizes = arrayType.getArraySizes();
569         for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
570         {
571             unsigned int arraySize = *it;
572             mArraySizeStack.push_back(arraySize);
573         }
574     }
575 
exitArray(const TType & arrayType)576     void exitArray(const TType &arrayType)
577     {
578         mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes());
579     }
580 
581     TCompiler *mCompiler;
582     int mRemovedUniformsCount;
583 
584     // Map structures with samplers to ones that have their samplers removed.
585     StructureMap mStructureMap;
586 
587     // Map uniform variables of structure type that are replaced with another variable.
588     StructureUniformMap mStructureUniformMap;
589 
590     // Map a constructed sampler name to its variable.  Used to replace an expression that uses this
591     // sampler with the extracted one.
592     ExtractedSamplerMap mExtractedSamplers;
593 
594     // A stack of array sizes.  Used to figure out the array dimensions of the extracted sampler,
595     // for example when it's nested in an array of structs in an array of structs.
596     TVector<unsigned int> mArraySizeStack;
597 };
598 }  // anonymous namespace
599 
RewriteStructSamplers(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int * removedUniformsCountOut)600 bool RewriteStructSamplers(TCompiler *compiler,
601                            TIntermBlock *root,
602                            TSymbolTable *symbolTable,
603                            int *removedUniformsCountOut)
604 {
605     RewriteStructSamplersTraverser traverser(compiler, symbolTable);
606     root->traverse(&traverser);
607     *removedUniformsCountOut = traverser.removedUniformsCount();
608     return traverser.updateTree(compiler, root);
609 }
610 }  // namespace sh
611