1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. Use of this
3 // source code is governed by a BSD-style license that can be found in the
4 // LICENSE file.
5 //
// ReplaceArrayOfMatrixVarying: Find any references to varyings that are arrays of matrices
// and replace them with arrays of vectors.
8 //
9
10 #include "compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.h"
11
12 #include <vector>
13
14 #include "common/bitset_utils.h"
15 #include "common/debug.h"
16 #include "common/utilities.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/SymbolTable.h"
19 #include "compiler/translator/tree_util/BuiltIn.h"
20 #include "compiler/translator/tree_util/FindMain.h"
21 #include "compiler/translator/tree_util/IntermNode_util.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 #include "compiler/translator/tree_util/ReplaceVariable.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 #include "compiler/translator/util.h"
26
27 namespace sh
28 {
29
// We create two variables to replace the given varying:
// - A new varying, which is an array of vectors used only for input/output.
// - A new global variable with the same type as the given variable. It is used temporarily as
// the replacement in assignments, arithmetic ops and so on. During the input/output phase, this
// temp variable is copied from/to the array-of-vectors variable above.
// NOTE(hqle): Consider eliminating the need for using temp variable.
36
37 namespace
38 {
39 class CollectVaryingTraverser : public TIntermTraverser
40 {
41 public:
CollectVaryingTraverser(std::vector<const TVariable * > * varyingsOut)42 CollectVaryingTraverser(std::vector<const TVariable *> *varyingsOut)
43 : TIntermTraverser(true, false, false), mVaryingsOut(varyingsOut)
44 {}
45
visitDeclaration(Visit visit,TIntermDeclaration * node)46 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
47 {
48 const TIntermSequence &sequence = *(node->getSequence());
49
50 if (sequence.size() != 1)
51 {
52 return false;
53 }
54
55 TIntermTyped *variableType = sequence.front()->getAsTyped();
56 if (!variableType || !IsVarying(variableType->getQualifier()) ||
57 !variableType->isMatrix() || !variableType->isArray())
58 {
59 return false;
60 }
61
62 TIntermSymbol *variableSymbol = variableType->getAsSymbolNode();
63 if (!variableSymbol)
64 {
65 return false;
66 }
67
68 mVaryingsOut->push_back(&variableSymbol->variable());
69
70 return false;
71 }
72
73 private:
74 std::vector<const TVariable *> *mVaryingsOut;
75 };
76 } // namespace
77
ReplaceArrayOfMatrixVarying(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TVariable * varying)78 [[nodiscard]] bool ReplaceArrayOfMatrixVarying(TCompiler *compiler,
79 TIntermBlock *root,
80 TSymbolTable *symbolTable,
81 const TVariable *varying)
82 {
83 const TType &type = varying->getType();
84
85 // Create global variable to temporarily acts as the given variable in places such as
86 // arithmetic, assignments an so on.
87 TType *tmpReplacementType = new TType(type);
88 tmpReplacementType->setQualifier(EvqGlobal);
89
90 TVariable *tempReplaceVar = new TVariable(
91 symbolTable, ImmutableString(std::string("ANGLE_AOM_Temp_") + varying->name().data()),
92 tmpReplacementType, SymbolType::AngleInternal);
93
94 if (!ReplaceVariable(compiler, root, varying, tempReplaceVar))
95 {
96 return false;
97 }
98
99 // Create array of vectors type
100 TType *varyingReplaceType = new TType(type);
101 varyingReplaceType->toMatrixColumnType();
102 varyingReplaceType->toArrayElementType();
103 varyingReplaceType->makeArray(type.getCols() * type.getOutermostArraySize());
104
105 TVariable *varyingReplaceVar =
106 new TVariable(symbolTable, varying->name(), varyingReplaceType, SymbolType::UserDefined);
107
108 TIntermSymbol *varyingReplaceDeclarator = new TIntermSymbol(varyingReplaceVar);
109 TIntermDeclaration *varyingReplaceDecl = new TIntermDeclaration;
110 varyingReplaceDecl->appendDeclarator(varyingReplaceDeclarator);
111 root->insertStatement(0, varyingReplaceDecl);
112
113 // Copy from/to the temp variable
114 TIntermBlock *reassignBlock = new TIntermBlock;
115 TIntermSymbol *tempReplaceSymbol = new TIntermSymbol(tempReplaceVar);
116 TIntermSymbol *varyingReplaceSymbol = new TIntermSymbol(varyingReplaceVar);
117 bool isInput = IsVaryingIn(type.getQualifier());
118
119 for (unsigned int i = 0; i < type.getOutermostArraySize(); ++i)
120 {
121 TIntermBinary *tempMatrixIndexed =
122 new TIntermBinary(EOpIndexDirect, tempReplaceSymbol->deepCopy(), CreateIndexNode(i));
123 for (uint8_t col = 0; col < type.getCols(); ++col)
124 {
125
126 TIntermBinary *tempMatrixColIndexed = new TIntermBinary(
127 EOpIndexDirect, tempMatrixIndexed->deepCopy(), CreateIndexNode(col));
128 TIntermBinary *vectorIndexed =
129 new TIntermBinary(EOpIndexDirect, varyingReplaceSymbol->deepCopy(),
130 CreateIndexNode(i * type.getCols() + col));
131 TIntermBinary *assignment;
132 if (isInput)
133 {
134 assignment = new TIntermBinary(EOpAssign, tempMatrixColIndexed, vectorIndexed);
135 }
136 else
137 {
138 assignment = new TIntermBinary(EOpAssign, vectorIndexed, tempMatrixColIndexed);
139 }
140 reassignBlock->appendStatement(assignment);
141 }
142 }
143
144 if (isInput)
145 {
146 TIntermFunctionDefinition *main = FindMain(root);
147 main->getBody()->insertStatement(0, reassignBlock);
148 return compiler->validateAST(root);
149 }
150 else
151 {
152 return RunAtTheEndOfShader(compiler, root, reassignBlock, symbolTable);
153 }
154 }
155
ReplaceArrayOfMatrixVaryings(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)156 [[nodiscard]] bool ReplaceArrayOfMatrixVaryings(TCompiler *compiler,
157 TIntermBlock *root,
158 TSymbolTable *symbolTable)
159 {
160 std::vector<const TVariable *> arrayOfMatrixVars;
161 CollectVaryingTraverser varCollector(&arrayOfMatrixVars);
162 root->traverse(&varCollector);
163
164 for (const TVariable *var : arrayOfMatrixVars)
165 {
166 if (!ReplaceArrayOfMatrixVarying(compiler, root, symbolTable, var))
167 {
168 return false;
169 }
170 }
171
172 return true;
173 }
174
175 } // namespace sh
176