xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/ReduceInterfaceBlocks.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 #include <unordered_map>
9 
10 #include "compiler/translator/IntermRebuild.h"
11 #include "compiler/translator/SymbolTable.h"
12 #include "compiler/translator/msl/AstHelpers.h"
13 #include "compiler/translator/msl/TranslatorMSL.h"
14 #include "compiler/translator/tree_ops/SeparateDeclarations.h"
15 #include "compiler/translator/tree_ops/msl/ReduceInterfaceBlocks.h"
16 
17 using namespace sh;
18 
19 ////////////////////////////////////////////////////////////////////////////////
20 
21 namespace
22 {
23 
24 class Reducer : public TIntermRebuild
25 {
26     std::unordered_map<const TInterfaceBlock *, const TVariable *> mLiftedMap;
27     std::unordered_map<const TVariable *, const TVariable *> mInstanceMap;
28     IdGen &mIdGen;
29 
30   public:
Reducer(TCompiler & compiler,IdGen & idGen)31     Reducer(TCompiler &compiler, IdGen &idGen)
32         : TIntermRebuild(compiler, true, false), mIdGen(idGen)
33     {}
34 
visitDeclarationPre(TIntermDeclaration & declNode)35     PreResult visitDeclarationPre(TIntermDeclaration &declNode) override
36     {
37         ASSERT(declNode.getChildCount() == 1);
38         TIntermNode &node = *declNode.getChildNode(0);
39 
40         if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
41         {
42             const TVariable &var        = symbolNode->variable();
43             const TType &type           = var.getType();
44             const SymbolType symbolType = var.symbolType();
45             if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
46             {
47                 if (symbolType == SymbolType::Empty)
48                 {
49                     // Create instance variable
50                     auto &structure =
51                         *new TStructure(&mSymbolTable, interfaceBlock->name(),
52                                         &interfaceBlock->fields(), interfaceBlock->symbolType());
53                     auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
54 
55                     auto &instanceVar = CreateInstanceVariable(
56                         mSymbolTable, structure, mIdGen.createNewName(interfaceBlock->name()),
57                         TQualifier::EvqBuffer, &type.getArraySizes());
58                     mLiftedMap[interfaceBlock] = &instanceVar;
59 
60                     TIntermNode *replacements[] = {
61                         new TIntermDeclaration{new TIntermSymbol(&structVar)},
62                         new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
63                     return PreResult::Multi(std::begin(replacements), std::end(replacements));
64                 }
65                 else
66                 {
67                     ASSERT(type.getQualifier() == TQualifier::EvqUniform);
68 
69                     auto &structure =
70                         *new TStructure(&mSymbolTable, interfaceBlock->name(),
71                                         &interfaceBlock->fields(), interfaceBlock->symbolType());
72                     auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
73                     auto &instanceVar =
74                         CreateInstanceVariable(mSymbolTable, structure, Name(var),
75                                                TQualifier::EvqBuffer, &type.getArraySizes());
76 
77                     mInstanceMap[&var] = &instanceVar;
78 
79                     TIntermNode *replacements[] = {
80                         new TIntermDeclaration{new TIntermSymbol(&structVar)},
81                         new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
82                     return PreResult::Multi(std::begin(replacements), std::end(replacements));
83                 }
84             }
85         }
86 
87         return {declNode, VisitBits::Both};
88     }
89 
visitSymbolPre(TIntermSymbol & symbolNode)90     PreResult visitSymbolPre(TIntermSymbol &symbolNode) override
91     {
92         const TVariable &var = symbolNode.variable();
93         {
94             auto it = mInstanceMap.find(&var);
95             if (it != mInstanceMap.end())
96             {
97                 return *new TIntermSymbol(it->second);
98             }
99         }
100         if (const TInterfaceBlock *ib = var.getType().getInterfaceBlock())
101         {
102             auto it = mLiftedMap.find(ib);
103             if (it != mLiftedMap.end())
104             {
105                 return AccessField(*(it->second), Name(var));
106             }
107         }
108         return symbolNode;
109     }
110 };
111 
112 }  // anonymous namespace
113 
114 ////////////////////////////////////////////////////////////////////////////////
115 
ReduceInterfaceBlocks(TCompiler & compiler,TIntermBlock & root,IdGen & idGen)116 bool sh::ReduceInterfaceBlocks(TCompiler &compiler, TIntermBlock &root, IdGen &idGen)
117 {
118     Reducer reducer(compiler, idGen);
119     if (!reducer.rebuildRoot(root))
120     {
121         return false;
122     }
123 
124     if (!SeparateDeclarations(compiler, root, false))
125     {
126         return false;
127     }
128 
129     return true;
130 }
131