1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 #include "compiler/translator/msl/AstHelpers.h"
13
14 using namespace sh;
15
16 ////////////////////////////////////////////////////////////////////////////////
17
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)18 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
19 const TStructure &structure)
20 {
21 TType *type = new TType(&structure, true);
22 TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
23 return *var;
24 }
25
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)26 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
27 const TStructure &structure,
28 const Name &name,
29 TQualifier qualifier,
30 const TSpan<const unsigned int> *arraySizes)
31 {
32 TType *type = new TType(&structure, false);
33 type->setQualifier(qualifier);
34 if (arraySizes)
35 {
36 type->makeArrays(*arraySizes);
37 }
38 TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
39 return *var;
40 }
41
AcquireFunctionExtras(TFunction & dest,const TFunction & src)42 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
43 {
44 if (src.isDefined())
45 {
46 dest.setDefined();
47 }
48
49 if (src.hasPrototypeDeclaration())
50 {
51 dest.setHasPrototypeDeclaration();
52 }
53 }
54
// Builds a new sequence consisting of `node` followed by every entry of `seq`, in order.
// The node pointers are shared with `seq`; `seq` itself is left untouched.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence &result = *new TIntermSequence();
    result.push_back(&node);

    for (auto it = seq.begin(); it != seq.end(); ++it)
    {
        result.push_back(*it);
    }

    return result;
}
67
AddParametersFrom(TFunction & dest,const TFunction & src)68 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
69 {
70 const size_t paramCount = src.getParamCount();
71 for (size_t i = 0; i < paramCount; ++i)
72 {
73 const TVariable *var = src.getParam(i);
74 dest.addParameter(var);
75 }
76 }
77
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)78 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
79 IdGen &idGen,
80 const TFunction &oldFunc)
81 {
82 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
83
84 Name newName = idGen.createNewName(Name(oldFunc));
85
86 TFunction &newFunc =
87 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
88 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
89
90 AcquireFunctionExtras(newFunc, oldFunc);
91 AddParametersFrom(newFunc, oldFunc);
92
93 return newFunc;
94 }
95
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)96 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
97 IdGen *idGen,
98 const TFunction &oldFunc,
99 const TVariable &newParam)
100 {
101 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
102 oldFunc.symbolType() == SymbolType::AngleInternal);
103
104 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
105
106 TFunction &newFunc =
107 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109
110 AcquireFunctionExtras(newFunc, oldFunc);
111 newFunc.addParameter(&newParam);
112 AddParametersFrom(newFunc, oldFunc);
113
114 return newFunc;
115 }
116
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)117 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
118 IdGen *idGen,
119 const TFunction &oldFunc,
120 const TVariable &newParam1,
121 const TVariable &newParam2)
122 {
123 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
124 oldFunc.symbolType() == SymbolType::AngleInternal);
125
126 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
127
128 TFunction &newFunc =
129 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
130 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
131
132 AcquireFunctionExtras(newFunc, oldFunc);
133 newFunc.addParameter(&newParam1);
134 newFunc.addParameter(&newParam2);
135 AddParametersFrom(newFunc, oldFunc);
136
137 return newFunc;
138 }
139
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)140 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
141 IdGen *idGen,
142 const TFunction &oldFunc,
143 const std::vector<const TVariable *> &newParams)
144 {
145 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
146 oldFunc.symbolType() == SymbolType::AngleInternal);
147
148 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
149
150 TFunction &newFunc =
151 *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
152 &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
153
154 AcquireFunctionExtras(newFunc, oldFunc);
155 AddParametersFrom(newFunc, oldFunc);
156 for (const TVariable *param : newParams)
157 {
158 newFunc.addParameter(param);
159 }
160
161 return newFunc;
162 }
163
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)164 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
165 IdGen *idGen,
166 const TFunction &oldFunc,
167 const TStructure &newReturn)
168 {
169 ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
170
171 Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
172
173 TType *newReturnType = new TType(&newReturn, true);
174 TFunction &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
175 newReturnType, oldFunc.isKnownToNotHaveSideEffects());
176
177 AcquireFunctionExtras(newFunc, oldFunc);
178 AddParametersFrom(newFunc, oldFunc);
179
180 return newFunc;
181 }
182
GetArg(const TIntermAggregate & call,size_t index)183 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
184 {
185 ASSERT(index < call.getChildCount());
186 TIntermNode *arg = call.getChildNode(index);
187 ASSERT(arg);
188 TIntermTyped *targ = arg->getAsTyped();
189 ASSERT(targ);
190 return *targ;
191 }
192
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)193 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
194 {
195 ASSERT(index < call.getChildCount());
196 (*call.getSequence())[index] = &arg;
197 }
198
AccessField(const TVariable & structInstanceVar,const Name & name)199 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const Name &name)
200 {
201 return AccessField(*new TIntermSymbol(&structInstanceVar), name);
202 }
203
AccessField(TIntermTyped & object,const Name & name)204 TIntermBinary &sh::AccessField(TIntermTyped &object, const Name &name)
205 {
206 const TStructure *structure = object.getType().getStruct();
207 ASSERT(structure);
208 const TFieldList &fieldList = structure->fields();
209 for (int i = 0; i < static_cast<int>(fieldList.size()); ++i)
210 {
211 TField *current = fieldList[i];
212 if (Name(*current) == name)
213 {
214 return AccessFieldByIndex(object, i);
215 }
216 }
217 UNREACHABLE();
218 return AccessFieldByIndex(object, -1);
219 }
220
AccessFieldByIndex(TIntermTyped & object,int index)221 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
222 {
223 #if defined(ANGLE_ENABLE_ASSERTS)
224 const TType &type = object.getType();
225 ASSERT(!type.isArray());
226 const TStructure *structure = type.getStruct();
227 ASSERT(structure);
228 ASSERT(0 <= index);
229 ASSERT(static_cast<size_t>(index) < structure->fields().size());
230 #endif
231
232 return *new TIntermBinary(
233 TOperator::EOpIndexDirectStruct, &object,
234 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
235 }
236
AccessIndex(TIntermTyped & indexableNode,int index)237 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
238 {
239 #if defined(ANGLE_ENABLE_ASSERTS)
240 const TType &type = indexableNode.getType();
241 ASSERT(type.isArray() || type.isVector() || type.isMatrix());
242 #endif
243
244 TIntermBinary *accessNode = new TIntermBinary(
245 TOperator::EOpIndexDirect, &indexableNode,
246 new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
247 return *accessNode;
248 }
249
AccessIndex(TIntermTyped & node,const int * index)250 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
251 {
252 if (index)
253 {
254 return AccessIndex(node, *index);
255 }
256 return node;
257 }
258
SubVector(TIntermTyped & vectorNode,int begin,int end)259 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
260 {
261 ASSERT(vectorNode.getType().isVector());
262 ASSERT(0 <= begin);
263 ASSERT(end <= 4);
264 ASSERT(begin <= end);
265 if (begin == 0 && end == vectorNode.getType().getNominalSize())
266 {
267 return vectorNode;
268 }
269 TVector<int> offsets(static_cast<size_t>(end - begin));
270 std::iota(offsets.begin(), offsets.end(), begin);
271 TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
272 return *swizzle->fold(nullptr); // Swizzles must always be folded to prevent double swizzles.
273 }
274
IsScalarBasicType(const TType & type)275 bool sh::IsScalarBasicType(const TType &type)
276 {
277 if (!type.isScalar())
278 {
279 return false;
280 }
281 return HasScalarBasicType(type);
282 }
283
IsVectorBasicType(const TType & type)284 bool sh::IsVectorBasicType(const TType &type)
285 {
286 if (!type.isVector())
287 {
288 return false;
289 }
290 return HasScalarBasicType(type);
291 }
292
HasScalarBasicType(TBasicType type)293 bool sh::HasScalarBasicType(TBasicType type)
294 {
295 switch (type)
296 {
297 case TBasicType::EbtFloat:
298 case TBasicType::EbtInt:
299 case TBasicType::EbtUInt:
300 case TBasicType::EbtBool:
301 return true;
302
303 default:
304 return false;
305 }
306 }
307
HasScalarBasicType(const TType & type)308 bool sh::HasScalarBasicType(const TType &type)
309 {
310 return HasScalarBasicType(type.getBasicType());
311 }
312
// Returns a newly allocated copy of `type`.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
318
// Returns a newly allocated copy of `type` on which toArrayBaseType() has been applied.
// `type` itself is not modified.
TType &sh::InnermostType(const TType &type)
{
    TType *const baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
325
// Returns a newly allocated copy of `matrixType` on which toMatrixColumnType() has been applied.
// The input must be a matrix with a scalar basic type (asserted) and is not modified.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    TType *const columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
335
// Returns a newly allocated copy of `arrayType` on which toArrayElementType() has been applied.
// The input must be an array (asserted) and is not modified.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    TType *const elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
344
// Returns a newly allocated copy of `type` with its primary and secondary sizes overwritten.
// Asserted preconditions: primary in (1, 4], secondary in [1, 4], scalar basic type.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    TType *const resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
356
// Returns a copy of `type` resized to a vector of `newDim` components (secondary size 1).
// `type` must be rank-0 or a vector (asserted).
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());

    constexpr int kVectorSecondarySize = 1;
    return SetTypeDimsImpl(type, newDim, kVectorSecondarySize);
}
362
// Returns a copy of `matrixType` that keeps its column count as the primary size and uses
// `newDim` as the secondary size. Requires a matrix and newDim in (1, 4] (asserted).
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);

    const auto columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
369
HasMatrixField(const TStructure & structure)370 bool sh::HasMatrixField(const TStructure &structure)
371 {
372 for (const TField *field : structure.fields())
373 {
374 const TType &type = *field->type();
375 if (type.isMatrix())
376 {
377 return true;
378 }
379 }
380 return false;
381 }
382
HasArrayField(const TStructure & structure)383 bool sh::HasArrayField(const TStructure &structure)
384 {
385 for (const TField *field : structure.fields())
386 {
387 const TType &type = *field->type();
388 if (type.isArray())
389 {
390 return true;
391 }
392 }
393 return false;
394 }
395
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)396 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
397 TIntermTyped &fromNode,
398 bool needsExplicitBoolCast)
399 {
400 const TType &fromType = fromNode.getType();
401
402 ASSERT(HasScalarBasicType(toBasicType));
403 ASSERT(HasScalarBasicType(fromType));
404 ASSERT(!fromType.isArray());
405
406 const TBasicType fromBasicType = fromType.getBasicType();
407
408 if (toBasicType != fromBasicType)
409 {
410 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
411 {
412 switch (fromBasicType)
413 {
414 case TBasicType::EbtFloat:
415 case TBasicType::EbtInt:
416 case TBasicType::EbtUInt:
417 {
418 TIntermSequence *argsSequence = new TIntermSequence();
419 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
420 {
421 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
422 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
423 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
424 argsSequence->push_back(boolConstructor);
425 }
426 return *TIntermAggregate::CreateConstructor(
427 *new TType(toBasicType, fromType.getNominalSize(),
428 fromType.getSecondarySize()),
429 argsSequence);
430 }
431
432 default:
433 break; // No explicit conversion needed
434 }
435 }
436
437 return *TIntermAggregate::CreateConstructor(
438 *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
439 new TIntermSequence{&fromNode});
440 }
441 return fromNode;
442 }
443
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)444 TIntermTyped &sh::CoerceSimple(const TType &toType,
445 TIntermTyped &fromNode,
446 bool needsExplicitBoolCast)
447 {
448 const TType &fromType = fromNode.getType();
449
450 ASSERT(HasScalarBasicType(toType));
451 ASSERT(HasScalarBasicType(fromType));
452 ASSERT(toType.getNominalSize() == fromType.getNominalSize());
453 ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
454 ASSERT(!toType.isArray());
455 ASSERT(!fromType.isArray());
456
457 const TBasicType toBasicType = toType.getBasicType();
458 const TBasicType fromBasicType = fromType.getBasicType();
459
460 if (toBasicType != fromBasicType)
461 {
462 if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
463 {
464 switch (fromBasicType)
465 {
466 case TBasicType::EbtFloat:
467 case TBasicType::EbtInt:
468 case TBasicType::EbtUInt:
469 {
470 TIntermSequence *argsSequence = new TIntermSequence();
471 for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
472 {
473 TIntermTyped &fromTypeSwizzle = SubVector(fromNode, i, i + 1);
474 TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
475 *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
476 argsSequence->push_back(boolConstructor);
477 }
478 return *TIntermAggregate::CreateConstructor(
479 *new TType(toBasicType, fromType.getNominalSize(),
480 fromType.getSecondarySize()),
481 new TIntermSequence{*argsSequence});
482 }
483
484 default:
485 break; // No explicit conversion needed
486 }
487 }
488
489 return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
490 }
491 return fromNode;
492 }
493
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)494 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
495 {
496 const TType &fromType = fromNode.getType();
497
498 ASSERT(HasScalarBasicType(toType));
499 ASSERT(HasScalarBasicType(fromType));
500 ASSERT(!toType.isArray());
501 ASSERT(!fromType.isArray());
502
503 if (toType == fromType)
504 {
505 return fromNode;
506 }
507 TemplateArg targ(toType);
508 return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
509 *new TIntermSequence{&fromNode}, 1, &targ);
510 }
511