1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10 
11 #include "compiler/translator/tree_ops/glsl/ScalarizeVecAndMatConstructorArgs.h"
12 
13 #include "angle_gl.h"
14 #include "common/angleutils.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 namespace sh
20 {
21 
22 namespace
23 {
GetHelperType(const TType & type,TQualifier qualifier)24 const TType *GetHelperType(const TType &type, TQualifier qualifier)
25 {
26     // If the type does not have a precision, it means that non of the parameters of the constructor
27     // have precision (for example because they are constants, or bool), and there is any precision
28     // propagation happening from nearby operands.  In that case, assign a highp precision to them;
29     // the driver will inline and eliminate the call anyway, and the precision does not affect
30     // anything.
31     constexpr TPrecision kDefaultPrecision = EbpHigh;
32 
33     TType *newType = new TType(type.getBasicType(), type.getNominalSize(), type.getSecondarySize());
34     if (type.getBasicType() != EbtBool)
35     {
36         newType->setPrecision(type.getPrecision() != EbpUndefined ? type.getPrecision()
37                                                                   : kDefaultPrecision);
38     }
39     newType->setQualifier(qualifier);
40 
41     return newType;
42 }
43 
44 // Traverser that converts a vector or matrix constructor to one that only uses scalars.  To support
45 // all the various places such a constructor could be found, a helper function is created for each
46 // such constructor.  The helper function takes the constructor arguments and creates the object.
47 //
48 // Constructors that are transformed are:
49 //
50 // - vecN(scalar): translates to vecN(scalar, ..., scalar)
51 // - vecN(vec1, vec2, ...): translates to vecN(vec1.x, vec1.y, vec2.x, ...)
52 // - vecN(matrix): translates to vecN(matrix[0][0], matrix[0][1], ...)
53 // - matNxM(scalar): translates to matNxM(scalar, 0, ..., 0
54 //                                        0, scalar, ..., 0
55 //                                        ...
56 //                                        0, 0, ..., scalar)
57 // - matNxM(vec1, vec2, ...): translates to matNxM(vec1.x, vec1.y, vec2.x, ...)
58 // - matNxM(matrixAxB): translates to matNxM(matrix[0][0], matrix[0][1], ..., 0
59 //                                           matrix[1][0], matrix[1][1], ..., 0
60 //                                           ...
61 //                                           0,            0,            ..., 1)
62 //
63 class ScalarizeTraverser : public TIntermTraverser
64 {
65   public:
ScalarizeTraverser(TSymbolTable * symbolTable)66     ScalarizeTraverser(TSymbolTable *symbolTable)
67         : TIntermTraverser(true, false, false, symbolTable)
68     {}
69 
70     bool update(TCompiler *compiler, TIntermBlock *root);
71 
72   protected:
73     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
74 
75   private:
76     bool shouldScalarize(TIntermTyped *node);
77 
78     // Create a helper function that takes the same arguments as the constructor it replaces.
79     const TFunction *createHelper(TIntermAggregate *node);
80     TIntermTyped *createHelperCall(TIntermAggregate *node, const TFunction *helper);
81     void addHelperDefinition(const TFunction *helper, TIntermBlock *body);
82 
83     // If given a constructor, convert it to a function call.  Recursively processes constructor
84     // arguments.  Otherwise, recursively visit the node.
85     TIntermTyped *createConstructor(TIntermTyped *node);
86 
87     void extractComponents(const TFunction *helper,
88                            size_t componentCount,
89                            TIntermSequence *componentsOut);
90 
91     void createConstructorVectorFromScalar(TIntermAggregate *node,
92                                            const TFunction *helper,
93                                            TIntermSequence *constructorArgsOut);
94     void createConstructorVectorFromMultiple(TIntermAggregate *node,
95                                              const TFunction *helper,
96                                              TIntermSequence *constructorArgsOut);
97     void createConstructorMatrixFromScalar(TIntermAggregate *node,
98                                            const TFunction *helper,
99                                            TIntermSequence *constructorArgsOut);
100     void createConstructorMatrixFromVectors(TIntermAggregate *node,
101                                             const TFunction *helper,
102                                             TIntermSequence *constructorArgsOut);
103     void createConstructorMatrixFromMatrix(TIntermAggregate *node,
104                                            const TFunction *helper,
105                                            TIntermSequence *constructorArgsOut);
106 
107     TIntermSequence mFunctionsToAdd;
108 };
109 
visitAggregate(Visit visit,TIntermAggregate * node)110 bool ScalarizeTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
111 {
112     if (!shouldScalarize(node))
113     {
114         return true;
115     }
116 
117     TIntermTyped *replacement = createConstructor(node);
118     if (replacement != node)
119     {
120         queueReplacement(replacement, OriginalNode::IS_DROPPED);
121     }
122     // createConstructor already visits children
123     return false;
124 }
125 
shouldScalarize(TIntermTyped * typed)126 bool ScalarizeTraverser::shouldScalarize(TIntermTyped *typed)
127 {
128     TIntermAggregate *node = typed->getAsAggregate();
129     if (node == nullptr || node->getOp() != EOpConstruct)
130     {
131         return false;
132     }
133 
134     const TType &type                = node->getType();
135     const TIntermSequence &arguments = *node->getSequence();
136     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
137 
138     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
139                                     arg0Type.isVector() &&
140                                     type.getNominalSize() == arg0Type.getNominalSize();
141     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
142                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
143                                     type.getRows() == arg0Type.getRows();
144 
145     // Skip non-vector non-matrix constructors, as well as trivial constructors.
146     if (type.isArray() || type.getStruct() != nullptr || type.isScalar() || isSingleVectorCast ||
147         isSingleMatrixCast)
148     {
149         return false;
150     }
151 
152     return true;
153 }
154 
createHelper(TIntermAggregate * node)155 const TFunction *ScalarizeTraverser::createHelper(TIntermAggregate *node)
156 {
157     TFunction *helper =
158         new TFunction(mSymbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
159                       GetHelperType(node->getType(), EvqTemporary), true);
160 
161     const TIntermSequence &arguments = *node->getSequence();
162     for (TIntermNode *arg : arguments)
163     {
164         const TType *argType = GetHelperType(arg->getAsTyped()->getType(), EvqParamIn);
165 
166         TVariable *argVar =
167             new TVariable(mSymbolTable, kEmptyImmutableString, argType, SymbolType::AngleInternal);
168         helper->addParameter(argVar);
169     }
170 
171     return helper;
172 }
173 
createHelperCall(TIntermAggregate * node,const TFunction * helper)174 TIntermTyped *ScalarizeTraverser::createHelperCall(TIntermAggregate *node, const TFunction *helper)
175 {
176     TIntermSequence callArgs;
177 
178     const TIntermSequence &arguments = *node->getSequence();
179     for (TIntermNode *arg : arguments)
180     {
181         // Note: createConstructor makes sure the arg is visited even if not constructor.
182         callArgs.push_back(createConstructor(arg->getAsTyped()));
183     }
184 
185     return TIntermAggregate::CreateFunctionCall(*helper, &callArgs);
186 }
187 
addHelperDefinition(const TFunction * helper,TIntermBlock * body)188 void ScalarizeTraverser::addHelperDefinition(const TFunction *helper, TIntermBlock *body)
189 {
190     mFunctionsToAdd.push_back(
191         new TIntermFunctionDefinition(new TIntermFunctionPrototype(helper), body));
192 }
193 
createConstructor(TIntermTyped * typed)194 TIntermTyped *ScalarizeTraverser::createConstructor(TIntermTyped *typed)
195 {
196     if (!shouldScalarize(typed))
197     {
198         typed->traverse(this);
199         return typed;
200     }
201 
202     TIntermAggregate *node           = typed->getAsAggregate();
203     const TType &type                = node->getType();
204     const TIntermSequence &arguments = *node->getSequence();
205     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
206 
207     const TFunction *helper = createHelper(node);
208     TIntermSequence constructorArgs;
209 
210     if (type.isVector())
211     {
212         if (arguments.size() == 1 && arg0Type.isScalar())
213         {
214             createConstructorVectorFromScalar(node, helper, &constructorArgs);
215         }
216         createConstructorVectorFromMultiple(node, helper, &constructorArgs);
217     }
218     else
219     {
220         ASSERT(type.isMatrix());
221 
222         if (arg0Type.isScalar() && arguments.size() == 1)
223         {
224             createConstructorMatrixFromScalar(node, helper, &constructorArgs);
225         }
226         if (arg0Type.isMatrix())
227         {
228             createConstructorMatrixFromMatrix(node, helper, &constructorArgs);
229         }
230         createConstructorMatrixFromVectors(node, helper, &constructorArgs);
231     }
232 
233     TIntermBlock *body = new TIntermBlock;
234     body->appendStatement(
235         new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, &constructorArgs)));
236     addHelperDefinition(helper, body);
237 
238     return createHelperCall(node, helper);
239 }
240 
241 // Extract enough scalar arguments from the arguments of helper to produce enough arguments for the
242 // constructor call (given in componentCount).
extractComponents(const TFunction * helper,size_t componentCount,TIntermSequence * componentsOut)243 void ScalarizeTraverser::extractComponents(const TFunction *helper,
244                                            size_t componentCount,
245                                            TIntermSequence *componentsOut)
246 {
247     for (size_t argumentIndex = 0;
248          argumentIndex < helper->getParamCount() && componentsOut->size() < componentCount;
249          ++argumentIndex)
250     {
251         TIntermTyped *argument    = new TIntermSymbol(helper->getParam(argumentIndex));
252         const TType &argumentType = argument->getType();
253 
254         if (argumentType.isScalar())
255         {
256             // For scalar parameters, there's nothing to do
257             componentsOut->push_back(argument);
258             continue;
259         }
260         if (argumentType.isVector())
261         {
262             // For vector parameters, take components out of the vector one by one.
263             for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
264                                              componentsOut->size() < componentCount;
265                  ++componentIndex)
266             {
267                 componentsOut->push_back(
268                     new TIntermSwizzle(argument->deepCopy(), {componentIndex}));
269             }
270             continue;
271         }
272 
273         ASSERT(argumentType.isMatrix());
274 
275         // For matrix parameters, take components out of the matrix one by one in column-major
276         // order.
277         for (uint8_t columnIndex = 0;
278              columnIndex < argumentType.getCols() && componentsOut->size() < componentCount;
279              ++columnIndex)
280         {
281             TIntermTyped *col = new TIntermBinary(EOpIndexDirect, argument->deepCopy(),
282                                                   CreateIndexNode(columnIndex));
283 
284             for (uint8_t componentIndex = 0;
285                  componentIndex < argumentType.getRows() && componentsOut->size() < componentCount;
286                  ++componentIndex)
287             {
288                 componentsOut->push_back(new TIntermSwizzle(col->deepCopy(), {componentIndex}));
289             }
290         }
291     }
292 }
293 
createConstructorVectorFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)294 void ScalarizeTraverser::createConstructorVectorFromScalar(TIntermAggregate *node,
295                                                            const TFunction *helper,
296                                                            TIntermSequence *constructorArgsOut)
297 {
298     ASSERT(helper->getParamCount() == 1);
299     TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
300     const TType &type    = node->getType();
301 
302     // Replicate the single scalar argument as many times as necessary.
303     for (size_t index = 0; index < type.getNominalSize(); ++index)
304     {
305         constructorArgsOut->push_back(scalar->deepCopy());
306     }
307 }
308 
createConstructorVectorFromMultiple(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)309 void ScalarizeTraverser::createConstructorVectorFromMultiple(TIntermAggregate *node,
310                                                              const TFunction *helper,
311                                                              TIntermSequence *constructorArgsOut)
312 {
313     extractComponents(helper, node->getType().getNominalSize(), constructorArgsOut);
314 }
315 
createConstructorMatrixFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)316 void ScalarizeTraverser::createConstructorMatrixFromScalar(TIntermAggregate *node,
317                                                            const TFunction *helper,
318                                                            TIntermSequence *constructorArgsOut)
319 {
320     ASSERT(helper->getParamCount() == 1);
321     TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
322     const TType &type    = node->getType();
323 
324     // Create the scalar over the diagonal.  Every other element is 0.
325     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
326     {
327         for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
328         {
329             if (columnIndex == rowIndex)
330             {
331                 constructorArgsOut->push_back(scalar->deepCopy());
332             }
333             else
334             {
335                 ASSERT(type.getBasicType() == EbtFloat);
336                 constructorArgsOut->push_back(CreateFloatNode(0, type.getPrecision()));
337             }
338         }
339     }
340 }
341 
createConstructorMatrixFromVectors(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)342 void ScalarizeTraverser::createConstructorMatrixFromVectors(TIntermAggregate *node,
343                                                             const TFunction *helper,
344                                                             TIntermSequence *constructorArgsOut)
345 {
346     const TType &type = node->getType();
347     extractComponents(helper, type.getCols() * type.getRows(), constructorArgsOut);
348 }
349 
createConstructorMatrixFromMatrix(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)350 void ScalarizeTraverser::createConstructorMatrixFromMatrix(TIntermAggregate *node,
351                                                            const TFunction *helper,
352                                                            TIntermSequence *constructorArgsOut)
353 {
354     ASSERT(helper->getParamCount() == 1);
355     TIntermTyped *matrix = new TIntermSymbol(helper->getParam(0));
356     const TType &type    = node->getType();
357 
358     // The result is the identity matrix with the size of the result, superimposed by the input
359     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
360     {
361         for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
362         {
363             if (columnIndex < matrix->getType().getCols() && rowIndex < matrix->getType().getRows())
364             {
365                 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, matrix->deepCopy(),
366                                                       CreateIndexNode(columnIndex));
367                 constructorArgsOut->push_back(
368                     new TIntermSwizzle(col, {static_cast<int>(rowIndex)}));
369             }
370             else
371             {
372                 ASSERT(type.getBasicType() == EbtFloat);
373                 constructorArgsOut->push_back(
374                     CreateFloatNode(columnIndex == rowIndex ? 1.0f : 0.0f, type.getPrecision()));
375             }
376         }
377     }
378 }
379 
update(TCompiler * compiler,TIntermBlock * root)380 bool ScalarizeTraverser::update(TCompiler *compiler, TIntermBlock *root)
381 {
382     // Insert any added function definitions at the tope of the block
383     root->insertChildNodes(0, mFunctionsToAdd);
384 
385     // Apply updates and validate
386     return updateTree(compiler, root);
387 }
388 }  // namespace
389 
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)390 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
391                                        TIntermBlock *root,
392                                        TSymbolTable *symbolTable)
393 {
394     ScalarizeTraverser scalarizer(symbolTable);
395     root->traverse(&scalarizer);
396     return scalarizer.update(compiler, root);
397 }
398 }  // namespace sh
399