1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10
11 #include "compiler/translator/tree_ops/glsl/ScalarizeVecAndMatConstructorArgs.h"
12
13 #include "angle_gl.h"
14 #include "common/angleutils.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18
19 namespace sh
20 {
21
22 namespace
23 {
// Builds the type used for a generated helper function's return value or parameter: a new type
// with the same basic type and sizes as |type|, qualified with |qualifier|, and (for every basic
// type except bool) carrying a precision.
const TType *GetHelperType(const TType &type, TQualifier qualifier)
{
    TType *helperType =
        new TType(type.getBasicType(), type.getNominalSize(), type.getSecondarySize());

    // bool types are left without a precision.
    if (type.getBasicType() != EbtBool)
    {
        // An undefined precision means none of the constructor's parameters had a precision (for
        // example because they are constants, or bool) and no precision was propagated from
        // nearby operands. Fall back to highp in that case; per the original authors, the driver
        // inlines and eliminates the helper call anyway, so the choice does not affect results.
        TPrecision precision = type.getPrecision();
        if (precision == EbpUndefined)
        {
            precision = EbpHigh;
        }
        helperType->setPrecision(precision);
    }

    helperType->setQualifier(qualifier);
    return helperType;
}
43
44 // Traverser that converts a vector or matrix constructor to one that only uses scalars. To support
45 // all the various places such a constructor could be found, a helper function is created for each
46 // such constructor. The helper function takes the constructor arguments and creates the object.
47 //
48 // Constructors that are transformed are:
49 //
50 // - vecN(scalar): translates to vecN(scalar, ..., scalar)
51 // - vecN(vec1, vec2, ...): translates to vecN(vec1.x, vec1.y, vec2.x, ...)
52 // - vecN(matrix): translates to vecN(matrix[0][0], matrix[0][1], ...)
53 // - matNxM(scalar): translates to matNxM(scalar, 0, ..., 0
54 // 0, scalar, ..., 0
55 // ...
56 // 0, 0, ..., scalar)
57 // - matNxM(vec1, vec2, ...): translates to matNxM(vec1.x, vec1.y, vec2.x, ...)
58 // - matNxM(matrixAxB): translates to matNxM(matrix[0][0], matrix[0][1], ..., 0
59 // matrix[1][0], matrix[1][1], ..., 0
60 // ...
61 // 0, 0, ..., 1)
62 //
63 class ScalarizeTraverser : public TIntermTraverser
64 {
65 public:
ScalarizeTraverser(TSymbolTable * symbolTable)66 ScalarizeTraverser(TSymbolTable *symbolTable)
67 : TIntermTraverser(true, false, false, symbolTable)
68 {}
69
70 bool update(TCompiler *compiler, TIntermBlock *root);
71
72 protected:
73 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
74
75 private:
76 bool shouldScalarize(TIntermTyped *node);
77
78 // Create a helper function that takes the same arguments as the constructor it replaces.
79 const TFunction *createHelper(TIntermAggregate *node);
80 TIntermTyped *createHelperCall(TIntermAggregate *node, const TFunction *helper);
81 void addHelperDefinition(const TFunction *helper, TIntermBlock *body);
82
83 // If given a constructor, convert it to a function call. Recursively processes constructor
84 // arguments. Otherwise, recursively visit the node.
85 TIntermTyped *createConstructor(TIntermTyped *node);
86
87 void extractComponents(const TFunction *helper,
88 size_t componentCount,
89 TIntermSequence *componentsOut);
90
91 void createConstructorVectorFromScalar(TIntermAggregate *node,
92 const TFunction *helper,
93 TIntermSequence *constructorArgsOut);
94 void createConstructorVectorFromMultiple(TIntermAggregate *node,
95 const TFunction *helper,
96 TIntermSequence *constructorArgsOut);
97 void createConstructorMatrixFromScalar(TIntermAggregate *node,
98 const TFunction *helper,
99 TIntermSequence *constructorArgsOut);
100 void createConstructorMatrixFromVectors(TIntermAggregate *node,
101 const TFunction *helper,
102 TIntermSequence *constructorArgsOut);
103 void createConstructorMatrixFromMatrix(TIntermAggregate *node,
104 const TFunction *helper,
105 TIntermSequence *constructorArgsOut);
106
107 TIntermSequence mFunctionsToAdd;
108 };
109
visitAggregate(Visit visit,TIntermAggregate * node)110 bool ScalarizeTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
111 {
112 if (!shouldScalarize(node))
113 {
114 return true;
115 }
116
117 TIntermTyped *replacement = createConstructor(node);
118 if (replacement != node)
119 {
120 queueReplacement(replacement, OriginalNode::IS_DROPPED);
121 }
122 // createConstructor already visits children
123 return false;
124 }
125
shouldScalarize(TIntermTyped * typed)126 bool ScalarizeTraverser::shouldScalarize(TIntermTyped *typed)
127 {
128 TIntermAggregate *node = typed->getAsAggregate();
129 if (node == nullptr || node->getOp() != EOpConstruct)
130 {
131 return false;
132 }
133
134 const TType &type = node->getType();
135 const TIntermSequence &arguments = *node->getSequence();
136 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
137
138 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
139 arg0Type.isVector() &&
140 type.getNominalSize() == arg0Type.getNominalSize();
141 const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
142 arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
143 type.getRows() == arg0Type.getRows();
144
145 // Skip non-vector non-matrix constructors, as well as trivial constructors.
146 if (type.isArray() || type.getStruct() != nullptr || type.isScalar() || isSingleVectorCast ||
147 isSingleMatrixCast)
148 {
149 return false;
150 }
151
152 return true;
153 }
154
createHelper(TIntermAggregate * node)155 const TFunction *ScalarizeTraverser::createHelper(TIntermAggregate *node)
156 {
157 TFunction *helper =
158 new TFunction(mSymbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
159 GetHelperType(node->getType(), EvqTemporary), true);
160
161 const TIntermSequence &arguments = *node->getSequence();
162 for (TIntermNode *arg : arguments)
163 {
164 const TType *argType = GetHelperType(arg->getAsTyped()->getType(), EvqParamIn);
165
166 TVariable *argVar =
167 new TVariable(mSymbolTable, kEmptyImmutableString, argType, SymbolType::AngleInternal);
168 helper->addParameter(argVar);
169 }
170
171 return helper;
172 }
173
createHelperCall(TIntermAggregate * node,const TFunction * helper)174 TIntermTyped *ScalarizeTraverser::createHelperCall(TIntermAggregate *node, const TFunction *helper)
175 {
176 TIntermSequence callArgs;
177
178 const TIntermSequence &arguments = *node->getSequence();
179 for (TIntermNode *arg : arguments)
180 {
181 // Note: createConstructor makes sure the arg is visited even if not constructor.
182 callArgs.push_back(createConstructor(arg->getAsTyped()));
183 }
184
185 return TIntermAggregate::CreateFunctionCall(*helper, &callArgs);
186 }
187
addHelperDefinition(const TFunction * helper,TIntermBlock * body)188 void ScalarizeTraverser::addHelperDefinition(const TFunction *helper, TIntermBlock *body)
189 {
190 mFunctionsToAdd.push_back(
191 new TIntermFunctionDefinition(new TIntermFunctionPrototype(helper), body));
192 }
193
createConstructor(TIntermTyped * typed)194 TIntermTyped *ScalarizeTraverser::createConstructor(TIntermTyped *typed)
195 {
196 if (!shouldScalarize(typed))
197 {
198 typed->traverse(this);
199 return typed;
200 }
201
202 TIntermAggregate *node = typed->getAsAggregate();
203 const TType &type = node->getType();
204 const TIntermSequence &arguments = *node->getSequence();
205 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
206
207 const TFunction *helper = createHelper(node);
208 TIntermSequence constructorArgs;
209
210 if (type.isVector())
211 {
212 if (arguments.size() == 1 && arg0Type.isScalar())
213 {
214 createConstructorVectorFromScalar(node, helper, &constructorArgs);
215 }
216 createConstructorVectorFromMultiple(node, helper, &constructorArgs);
217 }
218 else
219 {
220 ASSERT(type.isMatrix());
221
222 if (arg0Type.isScalar() && arguments.size() == 1)
223 {
224 createConstructorMatrixFromScalar(node, helper, &constructorArgs);
225 }
226 if (arg0Type.isMatrix())
227 {
228 createConstructorMatrixFromMatrix(node, helper, &constructorArgs);
229 }
230 createConstructorMatrixFromVectors(node, helper, &constructorArgs);
231 }
232
233 TIntermBlock *body = new TIntermBlock;
234 body->appendStatement(
235 new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, &constructorArgs)));
236 addHelperDefinition(helper, body);
237
238 return createHelperCall(node, helper);
239 }
240
241 // Extract enough scalar arguments from the arguments of helper to produce enough arguments for the
242 // constructor call (given in componentCount).
extractComponents(const TFunction * helper,size_t componentCount,TIntermSequence * componentsOut)243 void ScalarizeTraverser::extractComponents(const TFunction *helper,
244 size_t componentCount,
245 TIntermSequence *componentsOut)
246 {
247 for (size_t argumentIndex = 0;
248 argumentIndex < helper->getParamCount() && componentsOut->size() < componentCount;
249 ++argumentIndex)
250 {
251 TIntermTyped *argument = new TIntermSymbol(helper->getParam(argumentIndex));
252 const TType &argumentType = argument->getType();
253
254 if (argumentType.isScalar())
255 {
256 // For scalar parameters, there's nothing to do
257 componentsOut->push_back(argument);
258 continue;
259 }
260 if (argumentType.isVector())
261 {
262 // For vector parameters, take components out of the vector one by one.
263 for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
264 componentsOut->size() < componentCount;
265 ++componentIndex)
266 {
267 componentsOut->push_back(
268 new TIntermSwizzle(argument->deepCopy(), {componentIndex}));
269 }
270 continue;
271 }
272
273 ASSERT(argumentType.isMatrix());
274
275 // For matrix parameters, take components out of the matrix one by one in column-major
276 // order.
277 for (uint8_t columnIndex = 0;
278 columnIndex < argumentType.getCols() && componentsOut->size() < componentCount;
279 ++columnIndex)
280 {
281 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, argument->deepCopy(),
282 CreateIndexNode(columnIndex));
283
284 for (uint8_t componentIndex = 0;
285 componentIndex < argumentType.getRows() && componentsOut->size() < componentCount;
286 ++componentIndex)
287 {
288 componentsOut->push_back(new TIntermSwizzle(col->deepCopy(), {componentIndex}));
289 }
290 }
291 }
292 }
293
createConstructorVectorFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)294 void ScalarizeTraverser::createConstructorVectorFromScalar(TIntermAggregate *node,
295 const TFunction *helper,
296 TIntermSequence *constructorArgsOut)
297 {
298 ASSERT(helper->getParamCount() == 1);
299 TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
300 const TType &type = node->getType();
301
302 // Replicate the single scalar argument as many times as necessary.
303 for (size_t index = 0; index < type.getNominalSize(); ++index)
304 {
305 constructorArgsOut->push_back(scalar->deepCopy());
306 }
307 }
308
createConstructorVectorFromMultiple(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)309 void ScalarizeTraverser::createConstructorVectorFromMultiple(TIntermAggregate *node,
310 const TFunction *helper,
311 TIntermSequence *constructorArgsOut)
312 {
313 extractComponents(helper, node->getType().getNominalSize(), constructorArgsOut);
314 }
315
createConstructorMatrixFromScalar(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)316 void ScalarizeTraverser::createConstructorMatrixFromScalar(TIntermAggregate *node,
317 const TFunction *helper,
318 TIntermSequence *constructorArgsOut)
319 {
320 ASSERT(helper->getParamCount() == 1);
321 TIntermTyped *scalar = new TIntermSymbol(helper->getParam(0));
322 const TType &type = node->getType();
323
324 // Create the scalar over the diagonal. Every other element is 0.
325 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
326 {
327 for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
328 {
329 if (columnIndex == rowIndex)
330 {
331 constructorArgsOut->push_back(scalar->deepCopy());
332 }
333 else
334 {
335 ASSERT(type.getBasicType() == EbtFloat);
336 constructorArgsOut->push_back(CreateFloatNode(0, type.getPrecision()));
337 }
338 }
339 }
340 }
341
createConstructorMatrixFromVectors(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)342 void ScalarizeTraverser::createConstructorMatrixFromVectors(TIntermAggregate *node,
343 const TFunction *helper,
344 TIntermSequence *constructorArgsOut)
345 {
346 const TType &type = node->getType();
347 extractComponents(helper, type.getCols() * type.getRows(), constructorArgsOut);
348 }
349
createConstructorMatrixFromMatrix(TIntermAggregate * node,const TFunction * helper,TIntermSequence * constructorArgsOut)350 void ScalarizeTraverser::createConstructorMatrixFromMatrix(TIntermAggregate *node,
351 const TFunction *helper,
352 TIntermSequence *constructorArgsOut)
353 {
354 ASSERT(helper->getParamCount() == 1);
355 TIntermTyped *matrix = new TIntermSymbol(helper->getParam(0));
356 const TType &type = node->getType();
357
358 // The result is the identity matrix with the size of the result, superimposed by the input
359 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
360 {
361 for (uint8_t rowIndex = 0; rowIndex < type.getRows(); ++rowIndex)
362 {
363 if (columnIndex < matrix->getType().getCols() && rowIndex < matrix->getType().getRows())
364 {
365 TIntermTyped *col = new TIntermBinary(EOpIndexDirect, matrix->deepCopy(),
366 CreateIndexNode(columnIndex));
367 constructorArgsOut->push_back(
368 new TIntermSwizzle(col, {static_cast<int>(rowIndex)}));
369 }
370 else
371 {
372 ASSERT(type.getBasicType() == EbtFloat);
373 constructorArgsOut->push_back(
374 CreateFloatNode(columnIndex == rowIndex ? 1.0f : 0.0f, type.getPrecision()));
375 }
376 }
377 }
378 }
379
update(TCompiler * compiler,TIntermBlock * root)380 bool ScalarizeTraverser::update(TCompiler *compiler, TIntermBlock *root)
381 {
382 // Insert any added function definitions at the tope of the block
383 root->insertChildNodes(0, mFunctionsToAdd);
384
385 // Apply updates and validate
386 return updateTree(compiler, root);
387 }
388 } // namespace
389
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)390 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
391 TIntermBlock *root,
392 TSymbolTable *symbolTable)
393 {
394 ScalarizeTraverser scalarizer(symbolTable);
395 root->traverse(&scalarizer);
396 return scalarizer.update(compiler, root);
397 }
398 } // namespace sh
399