xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/ModifyStruct.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "compiler/translator/Compiler.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/msl/ModifyStruct.h"
#include "compiler/translator/msl/TranslatorMSL.h"
19 
20 using namespace sh;
21 
22 ////////////////////////////////////////////////////////////////////////////////
23 
size() const24 size_t ModifiedStructMachineries::size() const
25 {
26     return ordering.size();
27 }
28 
at(size_t index) const29 const ModifiedStructMachinery &ModifiedStructMachineries::at(size_t index) const
30 {
31     ASSERT(index < size());
32     const TStructure *s              = ordering[index];
33     const ModifiedStructMachinery *m = find(*s);
34     ASSERT(m);
35     return *m;
36 }
37 
find(const TStructure & s) const38 const ModifiedStructMachinery *ModifiedStructMachineries::find(const TStructure &s) const
39 {
40     auto iter = originalToMachinery.find(&s);
41     if (iter == originalToMachinery.end())
42     {
43         return nullptr;
44     }
45     return &iter->second;
46 }
47 
insert(const TStructure & s,const ModifiedStructMachinery & machinery)48 void ModifiedStructMachineries::insert(const TStructure &s,
49                                        const ModifiedStructMachinery &machinery)
50 {
51     ASSERT(!find(s));
52     originalToMachinery[&s] = machinery;
53     ordering.push_back(&s);
54 }
55 
56 ////////////////////////////////////////////////////////////////////////////////
57 
58 namespace
59 {
60 
Flatten(SymbolEnv & symbolEnv,TIntermTyped & node)61 TIntermTyped &Flatten(SymbolEnv &symbolEnv, TIntermTyped &node)
62 {
63     auto &type = node.getType();
64     ASSERT(type.isArray());
65 
66     auto &retType = InnermostType(type);
67     retType.makeArray(1);
68 
69     return symbolEnv.callFunctionOverload(Name("flatten"), retType, *new TIntermSequence{&node});
70 }
71 
// Empty tag type used to construct a PathItem of type PathItem::Type::FlattenArray.
struct FlattenArray
{
};
74 
75 struct PathItem
76 {
77     enum class Type
78     {
79         Field,         // Struct field indexing.
80         Index,         // Array, vector, or matrix indexing.
81         FlattenArray,  // Array of any rank -> pointer of innermost type.
82     };
83 
PathItem__anon192f35d00111::PathItem84     PathItem(const TField &field) : field(&field), type(Type::Field) {}
PathItem__anon192f35d00111::PathItem85     PathItem(int index) : index(index), type(Type::Index) {}
PathItem__anon192f35d00111::PathItem86     PathItem(unsigned index) : PathItem(static_cast<int>(index)) {}
PathItem__anon192f35d00111::PathItem87     PathItem(FlattenArray flatten) : type(Type::FlattenArray) {}
88 
89     union
90     {
91         const TField *field;
92         int index;
93     };
94     Type type;
95 };
96 
BuildPathAccess(SymbolEnv & symbolEnv,const TVariable & var,const std::vector<PathItem> & path)97 TIntermTyped &BuildPathAccess(SymbolEnv &symbolEnv,
98                               const TVariable &var,
99                               const std::vector<PathItem> &path)
100 {
101     TIntermTyped *curr = new TIntermSymbol(&var);
102     for (const PathItem &item : path)
103     {
104         switch (item.type)
105         {
106             case PathItem::Type::Field:
107                 curr = &AccessField(*curr, Name(*item.field));
108                 break;
109             case PathItem::Type::Index:
110                 curr = &AccessIndex(*curr, item.index);
111                 break;
112             case PathItem::Type::FlattenArray:
113             {
114                 curr = &Flatten(symbolEnv, *curr);
115             }
116             break;
117         }
118     }
119     return *curr;
120 }
121 
122 ////////////////////////////////////////////////////////////////////////////////
123 
124 using OriginalParam = const TVariable &;
125 using ModifiedParam = const TVariable &;
126 
127 using OriginalAccess = TIntermTyped;
128 using ModifiedAccess = TIntermTyped;
129 
130 struct Access
131 {
132     OriginalAccess &original;
133     ModifiedAccess &modified;
134 
135     struct Env
136     {
137         const ConvertType type;
138     };
139 };
140 
141 using ConversionFunc = std::function<Access(Access::Env &, OriginalAccess &, ModifiedAccess &)>;
142 
143 class ConvertStructState : angle::NonCopyable
144 {
145   private:
146     struct ConversionInfo
147     {
148         ConversionFunc stdFunc;
149         const TFunction *astFunc;
150         std::vector<PathItem> pathItems;
151         Name modifiedFieldName;
152     };
153 
154   public:
ConvertStructState(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,ModifiedStructMachineries & outMachineries,const bool isUBO,const bool useAttributeAliasing)155     ConvertStructState(TCompiler &compiler,
156                        SymbolEnv &symbolEnv,
157                        IdGen &idGen,
158                        const ModifyStructConfig &config,
159                        ModifiedStructMachineries &outMachineries,
160                        const bool isUBO,
161                        const bool useAttributeAliasing)
162         : mCompiler(compiler),
163           config(config),
164           symbolEnv(symbolEnv),
165           modifiedFields(*new TFieldList()),
166           symbolTable(symbolEnv.symbolTable()),
167           idGen(idGen),
168           outMachineries(outMachineries),
169           isUBO(isUBO),
170           useAttributeAliasing(useAttributeAliasing)
171     {}
172 
~ConvertStructState()173     ~ConvertStructState()
174     {
175     }
176 
publish(const TStructure & originalStruct,const Name & modifiedStructName)177     void publish(const TStructure &originalStruct, const Name &modifiedStructName)
178     {
179         const bool isOriginalToModified = config.convertType == ConvertType::OriginalToModified;
180 
181         auto &modifiedStruct = *new TStructure(&symbolTable, modifiedStructName.rawName(),
182                                                &modifiedFields, modifiedStructName.symbolType());
183 
184         auto &func = *new TFunction(
185             &symbolTable,
186             idGen.createNewName(isOriginalToModified ? "originalToModified" : "modifiedToOriginal")
187                 .rawName(),
188             SymbolType::AngleInternal, new TType(TBasicType::EbtVoid), false);
189 
190         OriginalParam originalParam =
191             CreateInstanceVariable(symbolTable, originalStruct, Name("original"));
192         ModifiedParam modifiedParam =
193             CreateInstanceVariable(symbolTable, modifiedStruct, Name("modified"));
194 
195         symbolEnv.markAsReference(originalParam, AddressSpace::Thread);
196         symbolEnv.markAsReference(modifiedParam, config.externalAddressSpace);
197         if (isOriginalToModified)
198         {
199             func.addParameter(&originalParam);
200             func.addParameter(&modifiedParam);
201         }
202         else
203         {
204             func.addParameter(&modifiedParam);
205             func.addParameter(&originalParam);
206         }
207 
208         TIntermBlock &body = *new TIntermBlock();
209 
210         Access::Env env{config.convertType};
211 
212         for (ConversionInfo &info : conversionInfos)
213         {
214             auto convert = [&](OriginalAccess &original, ModifiedAccess &modified) {
215                 if (info.astFunc)
216                 {
217                     ASSERT(!info.stdFunc);
218                     TIntermTyped &src  = isOriginalToModified ? modified : original;
219                     TIntermTyped &dest = isOriginalToModified ? original : modified;
220                     body.appendStatement(TIntermAggregate::CreateFunctionCall(
221                         *info.astFunc, new TIntermSequence{&dest, &src}));
222                 }
223                 else
224                 {
225                     ASSERT(info.stdFunc);
226                     Access access      = info.stdFunc(env, original, modified);
227                     TIntermTyped &src  = isOriginalToModified ? access.original : access.modified;
228                     TIntermTyped &dest = isOriginalToModified ? access.modified : access.original;
229                     body.appendStatement(new TIntermBinary(TOperator::EOpAssign, &dest, &src));
230                 }
231             };
232             OriginalAccess *original = &BuildPathAccess(symbolEnv, originalParam, info.pathItems);
233             ModifiedAccess *modified = &AccessField(modifiedParam, info.modifiedFieldName);
234             if (useAttributeAliasing)
235             {
236                 std::ostringstream aliasedName;
237                 aliasedName << "ANGLE_ALIASED_" << info.modifiedFieldName;
238 
239                 TType *placeholderType = new TType(modified->getType());
240                 placeholderType->setQualifier(EvqSpecConst);
241 
242                 modified = new TIntermSymbol(
243                     new TVariable(&symbolTable, sh::ImmutableString(aliasedName.str()),
244                                   placeholderType, SymbolType::AngleInternal));
245             }
246             const TType ot = original->getType();
247             const TType mt = modified->getType();
248             ASSERT(ot.isArray() == mt.isArray());
249 
250             // Clip distance output uses float[n] type, so the field must be assigned per-element
251             // when filling the modified struct. Explicit path name is used because original types
252             // are not available here.
253             if (ot.isArray() &&
254                 (ot.getLayoutQualifier().matrixPacking == EmpRowMajor || ot != mt ||
255                  info.modifiedFieldName == Name("gl_ClipDistance", SymbolType::BuiltIn)))
256             {
257                 ASSERT(ot.getArraySizes() == mt.getArraySizes());
258                 if (ot.isArrayOfArrays())
259                 {
260                     original = &Flatten(symbolEnv, *original);
261                     modified = &Flatten(symbolEnv, *modified);
262                 }
263                 const int volume = static_cast<int>(ot.getArraySizeProduct());
264                 for (int i = 0; i < volume; ++i)
265                 {
266                     if (i != 0)
267                     {
268                         original = original->deepCopy();
269                         modified = modified->deepCopy();
270                     }
271                     OriginalAccess &o = AccessIndex(*original, i);
272                     OriginalAccess &m = AccessIndex(*modified, i);
273                     convert(o, m);
274                 }
275             }
276             else
277             {
278                 convert(*original, *modified);
279             }
280         }
281 
282         auto *funcProto = new TIntermFunctionPrototype(&func);
283         auto *funcDef   = new TIntermFunctionDefinition(funcProto, &body);
284 
285         ModifiedStructMachinery machinery;
286         machinery.modifiedStruct                   = &modifiedStruct;
287         machinery.getConverter(config.convertType) = funcDef;
288 
289         outMachineries.insert(originalStruct, machinery);
290     }
291 
rootFieldName() const292     ImmutableString rootFieldName() const
293     {
294         if (!pathItems.empty())
295         {
296             if (pathItems[0].type == PathItem::Type::Field)
297             {
298                 return pathItems[0].field->name();
299             }
300         }
301         UNREACHABLE();
302         return kEmptyImmutableString;
303     }
304 
pushPath(PathItem const & item)305     void pushPath(PathItem const &item) { pathItems.push_back(item); }
306 
popPath()307     void popPath()
308     {
309         ASSERT(!pathItems.empty());
310         pathItems.pop_back();
311         if (pathItems.empty())
312         {
313             // Next push will start a new root output variable to linearize.
314             mSubfieldIndex = 0;
315         }
316     }
317 
finalize(const bool allowPadding)318     void finalize(const bool allowPadding)
319     {
320         ASSERT(!finalized);
321         finalized = true;
322         introducePacking();
323         ASSERT(metalLayoutTotal == Layout::Identity());
324         // Only pad substructs. We don't want to pad the structure that contains all the UBOs, only
325         // individual UBOs.
326         if (allowPadding)
327             introducePadding();
328     }
329 
addModifiedField(const TField & field,TType & newType,TLayoutBlockStorage storage,TLayoutMatrixPacking packing,const AddressSpace * addressSpace)330     void addModifiedField(const TField &field,
331                           TType &newType,
332                           TLayoutBlockStorage storage,
333                           TLayoutMatrixPacking packing,
334                           const AddressSpace *addressSpace)
335     {
336         TLayoutQualifier layoutQualifier = newType.getLayoutQualifier();
337         layoutQualifier.blockStorage     = storage;
338         layoutQualifier.matrixPacking    = packing;
339         newType.setLayoutQualifier(layoutQualifier);
340         sh::ImmutableString newName  = field.name();
341         sh::SymbolType newSymbolType = field.symbolType();
342         if (pathItems.size() > 1)
343         {
344             // Current state is linearizing a root input field into multiple modified fields. The
345             // new fields need unique names. Generate the new names into AngleInternal namespace.
346             // The user could choose a clashing name in UserDefined namespace.
347             newSymbolType = SymbolType::AngleInternal;
348             // The user specified root field name is currently used as the basis for the MSL vs-fs
349             // interface matching. The field linearization itself is deterministic, so subfield
350             // index is sufficient to define all the entries in MSL interface in all the compatible
351             // VS and FS MSL programs.
352             newName = BuildConcatenatedImmutableString(rootFieldName(), '_', mSubfieldIndex);
353             ++mSubfieldIndex;
354         }
355         TField *modifiedField = new TField(&newType, newName, field.line(), newSymbolType);
356         if (addressSpace)
357         {
358             symbolEnv.markAsPointer(*modifiedField, *addressSpace);
359         }
360         if (symbolEnv.isUBO(field))
361         {
362             symbolEnv.markAsUBO(*modifiedField);
363         }
364         modifiedFields.push_back(modifiedField);
365     }
366 
addConversion(const ConversionFunc & func)367     void addConversion(const ConversionFunc &func)
368     {
369         ASSERT(!modifiedFields.empty());
370         conversionInfos.push_back({func, nullptr, pathItems, Name(*modifiedFields.back())});
371     }
372 
addConversion(const TFunction & func)373     void addConversion(const TFunction &func)
374     {
375         ASSERT(!modifiedFields.empty());
376         conversionInfos.push_back({{}, &func, pathItems, Name(*modifiedFields.back())});
377     }
378 
hasPacking() const379     bool hasPacking() const { return containsPacked; }
380 
hasPadding() const381     bool hasPadding() const { return padFieldCount > 0; }
382 
recurse(const TStructure & structure,ModifiedStructMachinery & outMachinery,const bool isUBORecurse)383     bool recurse(const TStructure &structure,
384                  ModifiedStructMachinery &outMachinery,
385                  const bool isUBORecurse)
386     {
387         const ModifiedStructMachinery *m = outMachineries.find(structure);
388         if (m == nullptr)
389         {
390             TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
391             reflection->addOriginalName(structure.uniqueId().get(), structure.name().data());
392             const Name name = idGen.createNewName(structure.name().data());
393             if (!TryCreateModifiedStruct(mCompiler, symbolEnv, idGen, config, structure, name,
394                                          outMachineries, isUBORecurse, config.allowPadding, false))
395             {
396                 return false;
397             }
398             m = outMachineries.find(structure);
399             ASSERT(m);
400         }
401         outMachinery = *m;
402         return true;
403     }
404 
getIsUBO() const405     bool getIsUBO() const { return isUBO; }
406 
407   private:
addPadding(size_t padAmount,bool updateLayout)408     void addPadding(size_t padAmount, bool updateLayout)
409     {
410         if (padAmount == 0)
411         {
412             return;
413         }
414 
415         const size_t begin = modifiedFields.size();
416 
417         // Iteratively adding in scalar or vector padding because some struct types will not
418         // allow matrix or array members.
419         while (padAmount > 0)
420         {
421             TType *padType;
422             if (padAmount >= 16)
423             {
424                 padAmount -= 16;
425                 padType = new TType(TBasicType::EbtFloat, 4);
426             }
427             else if (padAmount >= 8)
428             {
429                 padAmount -= 8;
430                 padType = new TType(TBasicType::EbtFloat, 2);
431             }
432             else if (padAmount >= 4)
433             {
434                 padAmount -= 4;
435                 padType = new TType(TBasicType::EbtFloat);
436             }
437             else if (padAmount >= 2)
438             {
439                 padAmount -= 2;
440                 padType = new TType(TBasicType::EbtBool, 2);
441             }
442             else
443             {
444                 ASSERT(padAmount == 1);
445                 padAmount -= 1;
446                 padType = new TType(TBasicType::EbtBool);
447             }
448 
449             if (padType->getBasicType() != EbtBool)
450             {
451                 padType->setPrecision(EbpLow);
452             }
453 
454             if (updateLayout)
455             {
456                 metalLayoutTotal += MetalLayoutOf(*padType);
457             }
458 
459             const Name name = idGen.createNewName("pad");
460             modifiedFields.push_back(
461                 new TField(padType, name.rawName(), kNoSourceLoc, name.symbolType()));
462             ++padFieldCount;
463         }
464 
465         std::reverse(modifiedFields.begin() + begin, modifiedFields.end());
466     }
467 
introducePacking()468     void introducePacking()
469     {
470         if (!config.allowPacking)
471         {
472             return;
473         }
474 
475         auto setUnpackedStorage = [](TType &type) {
476             TLayoutBlockStorage storage = type.getLayoutQualifier().blockStorage;
477             switch (storage)
478             {
479                 case TLayoutBlockStorage::EbsShared:
480                     storage = TLayoutBlockStorage::EbsStd140;
481                     break;
482                 case TLayoutBlockStorage::EbsPacked:
483                     storage = TLayoutBlockStorage::EbsStd430;
484                     break;
485                 case TLayoutBlockStorage::EbsStd140:
486                 case TLayoutBlockStorage::EbsStd430:
487                 case TLayoutBlockStorage::EbsUnspecified:
488                     break;
489             }
490             SetBlockStorage(type, storage);
491         };
492 
493         Layout glslLayoutTotal = Layout::Identity();
494         const size_t size      = modifiedFields.size();
495 
496         for (size_t i = 0; i < size; ++i)
497         {
498             TField &curr           = *modifiedFields[i];
499             TType &currType        = *curr.type();
500             const bool canBePacked = CanBePacked(currType);
501 
502             auto dontPack = [&]() {
503                 if (canBePacked)
504                 {
505                     setUnpackedStorage(currType);
506                 }
507                 glslLayoutTotal += GlslLayoutOf(currType);
508             };
509 
510             if (!CanBePacked(currType))
511             {
512                 dontPack();
513                 continue;
514             }
515 
516             const Layout packedGlslLayout           = GlslLayoutOf(currType);
517             const TLayoutBlockStorage packedStorage = currType.getLayoutQualifier().blockStorage;
518             setUnpackedStorage(currType);
519             const Layout unpackedGlslLayout = GlslLayoutOf(currType);
520             SetBlockStorage(currType, packedStorage);
521 
522             ASSERT(packedGlslLayout.sizeOf <= unpackedGlslLayout.sizeOf);
523             if (packedGlslLayout.sizeOf == unpackedGlslLayout.sizeOf)
524             {
525                 dontPack();
526                 continue;
527             }
528 
529             const size_t j = i + 1;
530             if (j == size)
531             {
532                 dontPack();
533                 break;
534             }
535 
536             const size_t pad            = unpackedGlslLayout.sizeOf - packedGlslLayout.sizeOf;
537             const TField &next          = *modifiedFields[j];
538             const Layout nextGlslLayout = GlslLayoutOf(*next.type());
539 
540             if (pad < nextGlslLayout.sizeOf)
541             {
542                 dontPack();
543                 continue;
544             }
545 
546             symbolEnv.markAsPacked(curr);
547             glslLayoutTotal += packedGlslLayout;
548             containsPacked = true;
549         }
550     }
551 
introducePadding()552     void introducePadding()
553     {
554         if (!config.allowPadding)
555         {
556             return;
557         }
558 
559         MetalLayoutOfConfig layoutConfig;
560         layoutConfig.disablePacking             = !config.allowPacking;
561         layoutConfig.assumeStructsAreTailPadded = true;
562 
563         TFieldList fields = std::move(modifiedFields);
564         ASSERT(!fields.empty());  // GLSL requires at least one member.
565 
566         const TField *const first = fields.front();
567 
568         for (TField *field : fields)
569         {
570             const TType &type = *field->type();
571 
572             const Layout glslLayout  = GlslLayoutOf(type);
573             const Layout metalLayout = MetalLayoutOf(type, layoutConfig);
574 
575             size_t prePadAmount = 0;
576             if (glslLayout.alignOf > metalLayout.alignOf && field != first)
577             {
578                 const size_t prePaddedSize = metalLayoutTotal.sizeOf;
579                 metalLayoutTotal.requireAlignment(glslLayout.alignOf, true);
580                 const size_t paddedSize = metalLayoutTotal.sizeOf;
581                 prePadAmount            = paddedSize - prePaddedSize;
582                 metalLayoutTotal += metalLayout;
583                 addPadding(prePadAmount, false);  // Note: requireAlignment() already updated layout
584             }
585             else
586             {
587                 metalLayoutTotal += metalLayout;
588             }
589 
590             modifiedFields.push_back(field);
591 
592             if (glslLayout.sizeOf > metalLayout.sizeOf && field != fields.back())
593             {
594                 const bool updateLayout = true;  // XXX: Correct?
595                 const size_t padAmount  = glslLayout.sizeOf - metalLayout.sizeOf;
596                 addPadding(padAmount, updateLayout);
597             }
598         }
599     }
600 
601   public:
602     TCompiler &mCompiler;
603     const ModifyStructConfig &config;
604     SymbolEnv &symbolEnv;
605 
606   private:
607     TFieldList &modifiedFields;
608     Layout metalLayoutTotal = Layout::Identity();
609     size_t padFieldCount    = 0;
610     bool containsPacked     = false;
611     bool finalized          = false;
612 
613     std::vector<PathItem> pathItems;
614 
615     int mSubfieldIndex = 0;
616 
617     std::vector<ConversionInfo> conversionInfos;
618     TSymbolTable &symbolTable;
619     IdGen &idGen;
620     ModifiedStructMachineries &outMachineries;
621     const bool isUBO;
622     const bool useAttributeAliasing;
623 };
624 
625 ////////////////////////////////////////////////////////////////////////////////
626 
627 using ModifyFunc = bool (*)(ConvertStructState &state,
628                             const TField &field,
629                             const TLayoutBlockStorage storage,
630                             const TLayoutMatrixPacking packing);
631 
632 bool ModifyRecursive(ConvertStructState &state,
633                      const TField &field,
634                      const TLayoutBlockStorage storage,
635                      const TLayoutMatrixPacking packing);
636 
IdentityModify(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)637 bool IdentityModify(ConvertStructState &state,
638                     const TField &field,
639                     const TLayoutBlockStorage storage,
640                     const TLayoutMatrixPacking packing)
641 {
642     const TType &type = *field.type();
643     state.addModifiedField(field, CloneType(type), storage, packing, nullptr);
644     state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
645         return Access{o, m};
646     });
647     return false;
648 }
649 
InlineStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)650 bool InlineStruct(ConvertStructState &state,
651                   const TField &field,
652                   const TLayoutBlockStorage storage,
653                   const TLayoutMatrixPacking packing)
654 {
655     const TType &type              = *field.type();
656     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
657     if (!substructure)
658     {
659         return false;
660     }
661     if (type.isArray())
662     {
663         return false;
664     }
665     if (!state.config.inlineStruct(field))
666     {
667         return false;
668     }
669 
670     const TFieldList &subfields = substructure->fields();
671     for (const TField *subfield : subfields)
672     {
673         const TType &subtype                  = *subfield->type();
674         const TLayoutBlockStorage substorage  = Overlay(storage, subtype);
675         const TLayoutMatrixPacking subpacking = Overlay(packing, subtype);
676         ModifyRecursive(state, *subfield, substorage, subpacking);
677     }
678 
679     return true;
680 }
681 
RecurseStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)682 bool RecurseStruct(ConvertStructState &state,
683                    const TField &field,
684                    const TLayoutBlockStorage storage,
685                    const TLayoutMatrixPacking packing)
686 {
687     const TType &type              = *field.type();
688     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
689     if (!substructure)
690     {
691         return false;
692     }
693     if (!state.config.recurseStruct(field))
694     {
695         return false;
696     }
697 
698     ModifiedStructMachinery machinery;
699     if (!state.recurse(*substructure, machinery, state.getIsUBO()))
700     {
701         return false;
702     }
703 
704     TType &newType = *new TType(machinery.modifiedStruct, false);
705     if (type.isArray())
706     {
707         newType.makeArrays(type.getArraySizes());
708     }
709 
710     TIntermFunctionDefinition *converter = machinery.getConverter(state.config.convertType);
711     ASSERT(converter);
712 
713     state.addModifiedField(field, newType, storage, packing, state.symbolEnv.isPointer(field));
714     if (state.symbolEnv.isPointer(field))
715     {
716         state.symbolEnv.removePointer(field);
717     }
718     state.addConversion(*converter->getFunction());
719 
720     return true;
721 }
722 
SplitMatrixColumns(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)723 bool SplitMatrixColumns(ConvertStructState &state,
724                         const TField &field,
725                         const TLayoutBlockStorage storage,
726                         const TLayoutMatrixPacking packing)
727 {
728     const TType &type = *field.type();
729     if (!type.isMatrix())
730     {
731         return false;
732     }
733 
734     if (!state.config.splitMatrixColumns(field))
735     {
736         return false;
737     }
738 
739     const uint8_t cols = type.getCols();
740     TType &rowType     = DropColumns(type);
741 
742     for (uint8_t c = 0; c < cols; ++c)
743     {
744         state.pushPath(c);
745 
746         state.addModifiedField(field, rowType, storage, packing, state.symbolEnv.isPointer(field));
747         if (state.symbolEnv.isPointer(field))
748         {
749             state.symbolEnv.removePointer(field);
750         }
751         state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
752             return Access{o, m};
753         });
754 
755         state.popPath();
756     }
757 
758     return true;
759 }
760 
SaturateMatrixRows(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)761 bool SaturateMatrixRows(ConvertStructState &state,
762                         const TField &field,
763                         const TLayoutBlockStorage storage,
764                         const TLayoutMatrixPacking packing)
765 {
766     const TType &type = *field.type();
767     if (!type.isMatrix())
768     {
769         return false;
770     }
771     const bool isRowMajor    = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
772     const uint8_t rows       = type.getRows();
773     const uint8_t saturation = state.config.saturateMatrixRows(field);
774     if (saturation <= rows)
775     {
776         return false;
777     }
778 
779     const uint8_t cols = type.getCols();
780     TType &satType     = SetMatrixRowDim(type, saturation);
781     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
782     if (state.symbolEnv.isPointer(field))
783     {
784         state.symbolEnv.removePointer(field);
785     }
786 
787     for (uint8_t c = 0; c < cols; ++c)
788     {
789         for (uint8_t r = 0; r < rows; ++r)
790         {
791             state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
792                 uint8_t firstModifiedIndex  = isRowMajor ? r : c;
793                 uint8_t secondModifiedIndex = isRowMajor ? c : r;
794                 auto &o_                    = AccessIndex(AccessIndex(o, c), r);
795                 auto &m_ = AccessIndex(AccessIndex(m, firstModifiedIndex), secondModifiedIndex);
796                 return Access{o_, m_};
797             });
798         }
799     }
800 
801     return true;
802 }
803 
TestBoolToUint(ConvertStructState & state,const TField & field)804 bool TestBoolToUint(ConvertStructState &state, const TField &field)
805 {
806     if (field.type()->getBasicType() != TBasicType::EbtBool)
807     {
808         return false;
809     }
810     if (!state.config.promoteBoolToUint(field))
811     {
812         return false;
813     }
814     return true;
815 }
816 
ConvertBoolToUint(ConvertType convertType,OriginalAccess & o,ModifiedAccess & m)817 Access ConvertBoolToUint(ConvertType convertType, OriginalAccess &o, ModifiedAccess &m)
818 {
819     auto coerce = [](TIntermTyped &to, TIntermTyped &from) -> TIntermTyped & {
820         return *TIntermAggregate::CreateConstructor(to.getType(), new TIntermSequence{&from});
821     };
822     switch (convertType)
823     {
824         case ConvertType::OriginalToModified:
825             return Access{coerce(m, o), m};
826         case ConvertType::ModifiedToOriginal:
827             return Access{o, coerce(o, m)};
828     }
829 }
830 
SaturateScalarOrVectorCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing,const bool array)831 bool SaturateScalarOrVectorCommon(ConvertStructState &state,
832                                   const TField &field,
833                                   const TLayoutBlockStorage storage,
834                                   const TLayoutMatrixPacking packing,
835                                   const bool array)
836 {
837     const TType &type = *field.type();
838     if (type.isArray() != array)
839     {
840         return false;
841     }
842     if (!((type.isRank0() && HasScalarBasicType(type)) || type.isVector()))
843     {
844         return false;
845     }
846     const auto saturator =
847         array ? state.config.saturateScalarOrVectorArrays : state.config.saturateScalarOrVector;
848     const uint8_t dim        = type.getNominalSize();
849     const uint8_t saturation = saturator(field);
850     if (saturation <= dim)
851     {
852         return false;
853     }
854 
855     TType &satType        = SetVectorDim(type, saturation);
856     const bool boolToUint = TestBoolToUint(state, field);
857     if (boolToUint)
858     {
859         satType.setBasicType(TBasicType::EbtUInt);
860     }
861     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
862     if (state.symbolEnv.isPointer(field))
863     {
864         state.symbolEnv.removePointer(field);
865     }
866 
867     for (uint8_t d = 0; d < dim; ++d)
868     {
869         state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
870             auto &o_ = dim > 1 ? AccessIndex(o, d) : o;
871             auto &m_ = AccessIndex(m, d);
872             if (boolToUint)
873             {
874                 return ConvertBoolToUint(env.type, o_, m_);
875             }
876             else
877             {
878                 return Access{o_, m_};
879             }
880         });
881     }
882 
883     return true;
884 }
885 
SaturateScalarOrVectorArrays(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)886 bool SaturateScalarOrVectorArrays(ConvertStructState &state,
887                                   const TField &field,
888                                   const TLayoutBlockStorage storage,
889                                   const TLayoutMatrixPacking packing)
890 {
891     return SaturateScalarOrVectorCommon(state, field, storage, packing, true);
892 }
893 
SaturateScalarOrVector(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)894 bool SaturateScalarOrVector(ConvertStructState &state,
895                             const TField &field,
896                             const TLayoutBlockStorage storage,
897                             const TLayoutMatrixPacking packing)
898 {
899     return SaturateScalarOrVectorCommon(state, field, storage, packing, false);
900 }
901 
PromoteBoolToUint(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)902 bool PromoteBoolToUint(ConvertStructState &state,
903                        const TField &field,
904                        const TLayoutBlockStorage storage,
905                        const TLayoutMatrixPacking packing)
906 {
907     if (!TestBoolToUint(state, field))
908     {
909         return false;
910     }
911 
912     auto &promotedType = CloneType(*field.type());
913     promotedType.setBasicType(TBasicType::EbtUInt);
914     state.addModifiedField(field, promotedType, storage, packing, state.symbolEnv.isPointer(field));
915     if (state.symbolEnv.isPointer(field))
916     {
917         state.symbolEnv.removePointer(field);
918     }
919 
920     state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
921         return ConvertBoolToUint(env.type, o, m);
922     });
923 
924     return true;
925 }
926 
ModifyCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)927 bool ModifyCommon(ConvertStructState &state,
928                   const TField &field,
929                   const TLayoutBlockStorage storage,
930                   const TLayoutMatrixPacking packing)
931 {
932     ModifyFunc funcs[] = {
933         InlineStruct,                  //
934         RecurseStruct,                 //
935         SplitMatrixColumns,            //
936         SaturateMatrixRows,            //
937         SaturateScalarOrVectorArrays,  //
938         SaturateScalarOrVector,        //
939         PromoteBoolToUint,             //
940     };
941 
942     for (ModifyFunc func : funcs)
943     {
944         if (func(state, field, storage, packing))
945         {
946             return true;
947         }
948     }
949 
950     return IdentityModify(state, field, storage, packing);
951 }
952 
InlineArray(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)953 bool InlineArray(ConvertStructState &state,
954                  const TField &field,
955                  const TLayoutBlockStorage storage,
956                  const TLayoutMatrixPacking packing)
957 {
958     const TType &type = *field.type();
959     if (!type.isArray())
960     {
961         return false;
962     }
963     if (!state.config.inlineArray(field))
964     {
965         return false;
966     }
967 
968     const unsigned volume = type.getArraySizeProduct();
969     const bool isMultiDim = type.isArrayOfArrays();
970 
971     auto &innermostType = InnermostType(type);
972 
973     if (isMultiDim)
974     {
975         state.pushPath(FlattenArray());
976     }
977 
978     for (unsigned i = 0; i < volume; ++i)
979     {
980         state.pushPath(i);
981         TType setType(innermostType);
982         if (setType.getLayoutQualifier().locationsSpecified)
983         {
984             TLayoutQualifier qualifier(innermostType.getLayoutQualifier());
985             qualifier.location           = innermostType.getLayoutQualifier().location + i;
986             qualifier.locationsSpecified = 1;
987             setType.setLayoutQualifier(qualifier);
988         }
989         const TField innermostField(&setType, field.name(), field.line(), field.symbolType());
990         ModifyCommon(state, innermostField, storage, packing);
991         state.popPath();
992     }
993 
994     if (isMultiDim)
995     {
996         state.popPath();
997     }
998 
999     return true;
1000 }
1001 
ModifyRecursive(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)1002 bool ModifyRecursive(ConvertStructState &state,
1003                      const TField &field,
1004                      const TLayoutBlockStorage storage,
1005                      const TLayoutMatrixPacking packing)
1006 {
1007     state.pushPath(field);
1008 
1009     bool modified;
1010     if (InlineArray(state, field, storage, packing))
1011     {
1012         modified = true;
1013     }
1014     else
1015     {
1016         modified = ModifyCommon(state, field, storage, packing);
1017     }
1018 
1019     state.popPath();
1020 
1021     return modified;
1022 }
1023 
1024 }  // anonymous namespace
1025 
1026 ////////////////////////////////////////////////////////////////////////////////
1027 
TryCreateModifiedStruct(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,const TStructure & originalStruct,const Name & modifiedStructName,ModifiedStructMachineries & outMachineries,const bool isUBO,const bool allowPadding,const bool useAttributeAliasing)1028 bool sh::TryCreateModifiedStruct(TCompiler &compiler,
1029                                  SymbolEnv &symbolEnv,
1030                                  IdGen &idGen,
1031                                  const ModifyStructConfig &config,
1032                                  const TStructure &originalStruct,
1033                                  const Name &modifiedStructName,
1034                                  ModifiedStructMachineries &outMachineries,
1035                                  const bool isUBO,
1036                                  const bool allowPadding,
1037                                  const bool useAttributeAliasing)
1038 {
1039     ConvertStructState state(compiler, symbolEnv, idGen, config, outMachineries, isUBO,
1040                              useAttributeAliasing);
1041     size_t identicalFieldCount = 0;
1042 
1043     const TFieldList &originalFields = originalStruct.fields();
1044     for (TField *originalField : originalFields)
1045     {
1046         const TType &originalType          = *originalField->type();
1047         const TLayoutBlockStorage storage  = Overlay(config.initialBlockStorage, originalType);
1048         const TLayoutMatrixPacking packing = Overlay(config.initialMatrixPacking, originalType);
1049         if (!ModifyRecursive(state, *originalField, storage, packing))
1050         {
1051             ++identicalFieldCount;
1052         }
1053     }
1054 
1055     state.finalize(allowPadding);
1056 
1057     if (identicalFieldCount == originalFields.size() && !state.hasPacking() &&
1058         !state.hasPadding() && !useAttributeAliasing)
1059     {
1060         return false;
1061     }
1062     state.publish(originalStruct, modifiedStructName);
1063 
1064     return true;
1065 }
1066