xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. Use of this
3 // source code is governed by a BSD-style license that can be found in the
4 // LICENSE file.
5 //
6 // ReplaceArrayOfMatrixVarying: Find any references to array of matrices varying
7 // and replace it with array of vectors.
8 //
9 
10 #include "compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.h"
11 
12 #include <vector>
13 
14 #include "common/bitset_utils.h"
15 #include "common/debug.h"
16 #include "common/utilities.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/SymbolTable.h"
19 #include "compiler/translator/tree_util/BuiltIn.h"
20 #include "compiler/translator/tree_util/FindMain.h"
21 #include "compiler/translator/tree_util/IntermNode_util.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 #include "compiler/translator/tree_util/ReplaceVariable.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 #include "compiler/translator/util.h"
26 
27 namespace sh
28 {
29 
// We create two variables to replace the given varying:
// - A new varying, which is an array of vectors, used only for input/output.
// - A new global variable of the same type as the given varying, which temporarily replaces it
// in assignments, arithmetic ops and so on. During the input/output phase, this temp
// variable is copied from/to the array-of-vectors variable above.
// NOTE(hqle): Consider eliminating the need for using temp variable.
36 
37 namespace
38 {
39 class CollectVaryingTraverser : public TIntermTraverser
40 {
41   public:
CollectVaryingTraverser(std::vector<const TVariable * > * varyingsOut)42     CollectVaryingTraverser(std::vector<const TVariable *> *varyingsOut)
43         : TIntermTraverser(true, false, false), mVaryingsOut(varyingsOut)
44     {}
45 
visitDeclaration(Visit visit,TIntermDeclaration * node)46     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
47     {
48         const TIntermSequence &sequence = *(node->getSequence());
49 
50         if (sequence.size() != 1)
51         {
52             return false;
53         }
54 
55         TIntermTyped *variableType = sequence.front()->getAsTyped();
56         if (!variableType || !IsVarying(variableType->getQualifier()) ||
57             !variableType->isMatrix() || !variableType->isArray())
58         {
59             return false;
60         }
61 
62         TIntermSymbol *variableSymbol = variableType->getAsSymbolNode();
63         if (!variableSymbol)
64         {
65             return false;
66         }
67 
68         mVaryingsOut->push_back(&variableSymbol->variable());
69 
70         return false;
71     }
72 
73   private:
74     std::vector<const TVariable *> *mVaryingsOut;
75 };
76 }  // namespace
77 
ReplaceArrayOfMatrixVarying(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TVariable * varying)78 [[nodiscard]] bool ReplaceArrayOfMatrixVarying(TCompiler *compiler,
79                                                TIntermBlock *root,
80                                                TSymbolTable *symbolTable,
81                                                const TVariable *varying)
82 {
83     const TType &type = varying->getType();
84 
85     // Create global variable to temporarily acts as the given variable in places such as
86     // arithmetic, assignments an so on.
87     TType *tmpReplacementType = new TType(type);
88     tmpReplacementType->setQualifier(EvqGlobal);
89 
90     TVariable *tempReplaceVar = new TVariable(
91         symbolTable, ImmutableString(std::string("ANGLE_AOM_Temp_") + varying->name().data()),
92         tmpReplacementType, SymbolType::AngleInternal);
93 
94     if (!ReplaceVariable(compiler, root, varying, tempReplaceVar))
95     {
96         return false;
97     }
98 
99     // Create array of vectors type
100     TType *varyingReplaceType = new TType(type);
101     varyingReplaceType->toMatrixColumnType();
102     varyingReplaceType->toArrayElementType();
103     varyingReplaceType->makeArray(type.getCols() * type.getOutermostArraySize());
104 
105     TVariable *varyingReplaceVar =
106         new TVariable(symbolTable, varying->name(), varyingReplaceType, SymbolType::UserDefined);
107 
108     TIntermSymbol *varyingReplaceDeclarator = new TIntermSymbol(varyingReplaceVar);
109     TIntermDeclaration *varyingReplaceDecl  = new TIntermDeclaration;
110     varyingReplaceDecl->appendDeclarator(varyingReplaceDeclarator);
111     root->insertStatement(0, varyingReplaceDecl);
112 
113     // Copy from/to the temp variable
114     TIntermBlock *reassignBlock         = new TIntermBlock;
115     TIntermSymbol *tempReplaceSymbol    = new TIntermSymbol(tempReplaceVar);
116     TIntermSymbol *varyingReplaceSymbol = new TIntermSymbol(varyingReplaceVar);
117     bool isInput                        = IsVaryingIn(type.getQualifier());
118 
119     for (unsigned int i = 0; i < type.getOutermostArraySize(); ++i)
120     {
121         TIntermBinary *tempMatrixIndexed =
122             new TIntermBinary(EOpIndexDirect, tempReplaceSymbol->deepCopy(), CreateIndexNode(i));
123         for (uint8_t col = 0; col < type.getCols(); ++col)
124         {
125 
126             TIntermBinary *tempMatrixColIndexed = new TIntermBinary(
127                 EOpIndexDirect, tempMatrixIndexed->deepCopy(), CreateIndexNode(col));
128             TIntermBinary *vectorIndexed =
129                 new TIntermBinary(EOpIndexDirect, varyingReplaceSymbol->deepCopy(),
130                                   CreateIndexNode(i * type.getCols() + col));
131             TIntermBinary *assignment;
132             if (isInput)
133             {
134                 assignment = new TIntermBinary(EOpAssign, tempMatrixColIndexed, vectorIndexed);
135             }
136             else
137             {
138                 assignment = new TIntermBinary(EOpAssign, vectorIndexed, tempMatrixColIndexed);
139             }
140             reassignBlock->appendStatement(assignment);
141         }
142     }
143 
144     if (isInput)
145     {
146         TIntermFunctionDefinition *main = FindMain(root);
147         main->getBody()->insertStatement(0, reassignBlock);
148         return compiler->validateAST(root);
149     }
150     else
151     {
152         return RunAtTheEndOfShader(compiler, root, reassignBlock, symbolTable);
153     }
154 }
155 
ReplaceArrayOfMatrixVaryings(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)156 [[nodiscard]] bool ReplaceArrayOfMatrixVaryings(TCompiler *compiler,
157                                                 TIntermBlock *root,
158                                                 TSymbolTable *symbolTable)
159 {
160     std::vector<const TVariable *> arrayOfMatrixVars;
161     CollectVaryingTraverser varCollector(&arrayOfMatrixVars);
162     root->traverse(&varCollector);
163 
164     for (const TVariable *var : arrayOfMatrixVars)
165     {
166         if (!ReplaceArrayOfMatrixVarying(compiler, root, symbolTable, var))
167         {
168             return false;
169         }
170     }
171 
172     return true;
173 }
174 
175 }  // namespace sh
176