xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/AstHelpers.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 #include "compiler/translator/msl/AstHelpers.h"
13 
14 using namespace sh;
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)18 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
19                                               const TStructure &structure)
20 {
21     TType *type    = new TType(&structure, true);
22     TVariable *var = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
23     return *var;
24 }
25 
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)26 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
27                                             const TStructure &structure,
28                                             const Name &name,
29                                             TQualifier qualifier,
30                                             const TSpan<const unsigned int> *arraySizes)
31 {
32     TType *type = new TType(&structure, false);
33     type->setQualifier(qualifier);
34     if (arraySizes)
35     {
36         type->makeArrays(*arraySizes);
37     }
38     TVariable *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
39     return *var;
40 }
41 
AcquireFunctionExtras(TFunction & dest,const TFunction & src)42 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
43 {
44     if (src.isDefined())
45     {
46         dest.setDefined();
47     }
48 
49     if (src.hasPrototypeDeclaration())
50     {
51         dest.setHasPrototypeDeclaration();
52     }
53 }
54 
// Returns a newly allocated sequence containing |node| followed by every element of |seq|, in
// order. This is a shallow copy: the node pointers are shared with |seq|, not deep-copied.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence &result = *new TIntermSequence();
    result.reserve(seq.size() + 1);
    result.push_back(&node);
    result.insert(result.end(), seq.begin(), seq.end());
    return result;
}
67 
AddParametersFrom(TFunction & dest,const TFunction & src)68 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
69 {
70     const size_t paramCount = src.getParamCount();
71     for (size_t i = 0; i < paramCount; ++i)
72     {
73         const TVariable *var = src.getParam(i);
74         dest.addParameter(var);
75     }
76 }
77 
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)78 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
79                                    IdGen &idGen,
80                                    const TFunction &oldFunc)
81 {
82     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
83 
84     Name newName = idGen.createNewName(Name(oldFunc));
85 
86     TFunction &newFunc =
87         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
88                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
89 
90     AcquireFunctionExtras(newFunc, oldFunc);
91     AddParametersFrom(newFunc, oldFunc);
92 
93     return newFunc;
94 }
95 
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)96 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
97                                                   IdGen *idGen,
98                                                   const TFunction &oldFunc,
99                                                   const TVariable &newParam)
100 {
101     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
102            oldFunc.symbolType() == SymbolType::AngleInternal);
103 
104     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
105 
106     TFunction &newFunc =
107         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109 
110     AcquireFunctionExtras(newFunc, oldFunc);
111     newFunc.addParameter(&newParam);
112     AddParametersFrom(newFunc, oldFunc);
113 
114     return newFunc;
115 }
116 
CloneFunctionAndPrependTwoParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam1,const TVariable & newParam2)117 const TFunction &sh::CloneFunctionAndPrependTwoParams(TSymbolTable &symbolTable,
118                                                       IdGen *idGen,
119                                                       const TFunction &oldFunc,
120                                                       const TVariable &newParam1,
121                                                       const TVariable &newParam2)
122 {
123     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
124            oldFunc.symbolType() == SymbolType::AngleInternal);
125 
126     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
127 
128     TFunction &newFunc =
129         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
130                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
131 
132     AcquireFunctionExtras(newFunc, oldFunc);
133     newFunc.addParameter(&newParam1);
134     newFunc.addParameter(&newParam2);
135     AddParametersFrom(newFunc, oldFunc);
136 
137     return newFunc;
138 }
139 
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)140 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
141                                                   IdGen *idGen,
142                                                   const TFunction &oldFunc,
143                                                   const std::vector<const TVariable *> &newParams)
144 {
145     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
146            oldFunc.symbolType() == SymbolType::AngleInternal);
147 
148     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
149 
150     TFunction &newFunc =
151         *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
152                        &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
153 
154     AcquireFunctionExtras(newFunc, oldFunc);
155     AddParametersFrom(newFunc, oldFunc);
156     for (const TVariable *param : newParams)
157     {
158         newFunc.addParameter(param);
159     }
160 
161     return newFunc;
162 }
163 
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)164 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
165                                                       IdGen *idGen,
166                                                       const TFunction &oldFunc,
167                                                       const TStructure &newReturn)
168 {
169     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
170 
171     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
172 
173     TType *newReturnType = new TType(&newReturn, true);
174     TFunction &newFunc   = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
175                                           newReturnType, oldFunc.isKnownToNotHaveSideEffects());
176 
177     AcquireFunctionExtras(newFunc, oldFunc);
178     AddParametersFrom(newFunc, oldFunc);
179 
180     return newFunc;
181 }
182 
GetArg(const TIntermAggregate & call,size_t index)183 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
184 {
185     ASSERT(index < call.getChildCount());
186     TIntermNode *arg = call.getChildNode(index);
187     ASSERT(arg);
188     TIntermTyped *targ = arg->getAsTyped();
189     ASSERT(targ);
190     return *targ;
191 }
192 
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)193 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
194 {
195     ASSERT(index < call.getChildCount());
196     (*call.getSequence())[index] = &arg;
197 }
198 
AccessField(const TVariable & structInstanceVar,const Name & name)199 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const Name &name)
200 {
201     return AccessField(*new TIntermSymbol(&structInstanceVar), name);
202 }
203 
AccessField(TIntermTyped & object,const Name & name)204 TIntermBinary &sh::AccessField(TIntermTyped &object, const Name &name)
205 {
206     const TStructure *structure = object.getType().getStruct();
207     ASSERT(structure);
208     const TFieldList &fieldList = structure->fields();
209     for (int i = 0; i < static_cast<int>(fieldList.size()); ++i)
210     {
211         TField *current = fieldList[i];
212         if (Name(*current) == name)
213         {
214             return AccessFieldByIndex(object, i);
215         }
216     }
217     UNREACHABLE();
218     return AccessFieldByIndex(object, -1);
219 }
220 
AccessFieldByIndex(TIntermTyped & object,int index)221 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
222 {
223 #if defined(ANGLE_ENABLE_ASSERTS)
224     const TType &type = object.getType();
225     ASSERT(!type.isArray());
226     const TStructure *structure = type.getStruct();
227     ASSERT(structure);
228     ASSERT(0 <= index);
229     ASSERT(static_cast<size_t>(index) < structure->fields().size());
230 #endif
231 
232     return *new TIntermBinary(
233         TOperator::EOpIndexDirectStruct, &object,
234         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
235 }
236 
AccessIndex(TIntermTyped & indexableNode,int index)237 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
238 {
239 #if defined(ANGLE_ENABLE_ASSERTS)
240     const TType &type = indexableNode.getType();
241     ASSERT(type.isArray() || type.isVector() || type.isMatrix());
242 #endif
243 
244     TIntermBinary *accessNode = new TIntermBinary(
245         TOperator::EOpIndexDirect, &indexableNode,
246         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
247     return *accessNode;
248 }
249 
AccessIndex(TIntermTyped & node,const int * index)250 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
251 {
252     if (index)
253     {
254         return AccessIndex(node, *index);
255     }
256     return node;
257 }
258 
SubVector(TIntermTyped & vectorNode,int begin,int end)259 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
260 {
261     ASSERT(vectorNode.getType().isVector());
262     ASSERT(0 <= begin);
263     ASSERT(end <= 4);
264     ASSERT(begin <= end);
265     if (begin == 0 && end == vectorNode.getType().getNominalSize())
266     {
267         return vectorNode;
268     }
269     TVector<int> offsets(static_cast<size_t>(end - begin));
270     std::iota(offsets.begin(), offsets.end(), begin);
271     TIntermSwizzle *swizzle = new TIntermSwizzle(vectorNode.deepCopy(), offsets);
272     return *swizzle->fold(nullptr);  // Swizzles must always be folded to prevent double swizzles.
273 }
274 
IsScalarBasicType(const TType & type)275 bool sh::IsScalarBasicType(const TType &type)
276 {
277     if (!type.isScalar())
278     {
279         return false;
280     }
281     return HasScalarBasicType(type);
282 }
283 
IsVectorBasicType(const TType & type)284 bool sh::IsVectorBasicType(const TType &type)
285 {
286     if (!type.isVector())
287     {
288         return false;
289     }
290     return HasScalarBasicType(type);
291 }
292 
HasScalarBasicType(TBasicType type)293 bool sh::HasScalarBasicType(TBasicType type)
294 {
295     switch (type)
296     {
297         case TBasicType::EbtFloat:
298         case TBasicType::EbtInt:
299         case TBasicType::EbtUInt:
300         case TBasicType::EbtBool:
301             return true;
302 
303         default:
304             return false;
305     }
306 }
307 
HasScalarBasicType(const TType & type)308 bool sh::HasScalarBasicType(const TType &type)
309 {
310     return HasScalarBasicType(type.getBasicType());
311 }
312 
// Returns a newly allocated copy of |type|.
TType &sh::CloneType(const TType &type)
{
    return *new TType(type);
}
318 
// Returns a newly allocated copy of |type| converted by TType::toArrayBaseType().
TType &sh::InnermostType(const TType &type)
{
    auto *baseType = new TType(type);
    baseType->toArrayBaseType();
    return *baseType;
}
325 
// Returns a newly allocated copy of the scalar-based matrix type |matrixType| converted by
// TType::toMatrixColumnType().
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    auto *columnType = new TType(matrixType);
    columnType->toMatrixColumnType();
    return *columnType;
}
335 
// Returns a newly allocated copy of the array type |arrayType| converted by
// TType::toArrayElementType().
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    auto *elementType = new TType(arrayType);
    elementType->toArrayElementType();
    return *elementType;
}
344 
// Returns a newly allocated copy of |type| with its primary size set to |primary| (asserted to
// be 2..4) and its secondary size set to |secondary| (asserted to be 1..4).
// |type| must have a scalar basic type.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(primary > 1 && primary <= 4);
    ASSERT(secondary >= 1 && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    auto *resized = new TType(type);
    resized->setPrimarySize(primary);
    resized->setSecondarySize(secondary);
    return *resized;
}
356 
// Returns a copy of the rank-0 or vector type |type| resized to |newDim| components.
// The secondary size becomes 1. SetTypeDimsImpl asserts newDim is in 2..4.
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    constexpr int kVectorSecondarySize = 1;
    return SetTypeDimsImpl(type, newDim, kVectorSecondarySize);
}
362 
// Returns a copy of |matrixType| that keeps its column count (primary size) and has its
// secondary size set to |newDim|, which must be 2..4.
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(newDim > 1 && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
369 
HasMatrixField(const TStructure & structure)370 bool sh::HasMatrixField(const TStructure &structure)
371 {
372     for (const TField *field : structure.fields())
373     {
374         const TType &type = *field->type();
375         if (type.isMatrix())
376         {
377             return true;
378         }
379     }
380     return false;
381 }
382 
HasArrayField(const TStructure & structure)383 bool sh::HasArrayField(const TStructure &structure)
384 {
385     for (const TField *field : structure.fields())
386     {
387         const TType &type = *field->type();
388         if (type.isArray())
389         {
390             return true;
391         }
392     }
393     return false;
394 }
395 
CoerceSimple(TBasicType toBasicType,TIntermTyped & fromNode,bool needsExplicitBoolCast)396 TIntermTyped &sh::CoerceSimple(TBasicType toBasicType,
397                                TIntermTyped &fromNode,
398                                bool needsExplicitBoolCast)
399 {
400     const TType &fromType = fromNode.getType();
401 
402     ASSERT(HasScalarBasicType(toBasicType));
403     ASSERT(HasScalarBasicType(fromType));
404     ASSERT(!fromType.isArray());
405 
406     const TBasicType fromBasicType = fromType.getBasicType();
407 
408     if (toBasicType != fromBasicType)
409     {
410         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
411         {
412             switch (fromBasicType)
413             {
414                 case TBasicType::EbtFloat:
415                 case TBasicType::EbtInt:
416                 case TBasicType::EbtUInt:
417                 {
418                     TIntermSequence *argsSequence = new TIntermSequence();
419                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
420                     {
421                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
422                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
423                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
424                         argsSequence->push_back(boolConstructor);
425                     }
426                     return *TIntermAggregate::CreateConstructor(
427                         *new TType(toBasicType, fromType.getNominalSize(),
428                                    fromType.getSecondarySize()),
429                         argsSequence);
430                 }
431 
432                 default:
433                     break;  // No explicit conversion needed
434             }
435         }
436 
437         return *TIntermAggregate::CreateConstructor(
438             *new TType(toBasicType, fromType.getNominalSize(), fromType.getSecondarySize()),
439             new TIntermSequence{&fromNode});
440     }
441     return fromNode;
442 }
443 
CoerceSimple(const TType & toType,TIntermTyped & fromNode,bool needsExplicitBoolCast)444 TIntermTyped &sh::CoerceSimple(const TType &toType,
445                                TIntermTyped &fromNode,
446                                bool needsExplicitBoolCast)
447 {
448     const TType &fromType = fromNode.getType();
449 
450     ASSERT(HasScalarBasicType(toType));
451     ASSERT(HasScalarBasicType(fromType));
452     ASSERT(toType.getNominalSize() == fromType.getNominalSize());
453     ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
454     ASSERT(!toType.isArray());
455     ASSERT(!fromType.isArray());
456 
457     const TBasicType toBasicType   = toType.getBasicType();
458     const TBasicType fromBasicType = fromType.getBasicType();
459 
460     if (toBasicType != fromBasicType)
461     {
462         if (toBasicType == TBasicType::EbtBool && fromNode.isVector() && needsExplicitBoolCast)
463         {
464             switch (fromBasicType)
465             {
466                 case TBasicType::EbtFloat:
467                 case TBasicType::EbtInt:
468                 case TBasicType::EbtUInt:
469                 {
470                     TIntermSequence *argsSequence = new TIntermSequence();
471                     for (uint8_t i = 0; i < fromType.getNominalSize(); i++)
472                     {
473                         TIntermTyped &fromTypeSwizzle     = SubVector(fromNode, i, i + 1);
474                         TIntermAggregate *boolConstructor = TIntermAggregate::CreateConstructor(
475                             *new TType(toBasicType, 1, 1), new TIntermSequence{&fromTypeSwizzle});
476                         argsSequence->push_back(boolConstructor);
477                     }
478                     return *TIntermAggregate::CreateConstructor(
479                         *new TType(toBasicType, fromType.getNominalSize(),
480                                    fromType.getSecondarySize()),
481                         new TIntermSequence{*argsSequence});
482                 }
483 
484                 default:
485                     break;  // No explicit conversion needed
486             }
487         }
488 
489         return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
490     }
491     return fromNode;
492 }
493 
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)494 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
495 {
496     const TType &fromType = fromNode.getType();
497 
498     ASSERT(HasScalarBasicType(toType));
499     ASSERT(HasScalarBasicType(fromType));
500     ASSERT(!toType.isArray());
501     ASSERT(!fromType.isArray());
502 
503     if (toType == fromType)
504     {
505         return fromNode;
506     }
507     TemplateArg targ(toType);
508     return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
509                                           *new TIntermSequence{&fromNode}, 1, &targ);
510 }
511