1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8 #include <unordered_map>
9
10 #include "compiler/translator/IntermRebuild.h"
11 #include "compiler/translator/SymbolTable.h"
12 #include "compiler/translator/msl/AstHelpers.h"
13 #include "compiler/translator/msl/TranslatorMSL.h"
14 #include "compiler/translator/tree_ops/SeparateDeclarations.h"
15 #include "compiler/translator/tree_ops/msl/ReduceInterfaceBlocks.h"
16
17 using namespace sh;
18
19 ////////////////////////////////////////////////////////////////////////////////
20
21 namespace
22 {
23
24 class Reducer : public TIntermRebuild
25 {
26 std::unordered_map<const TInterfaceBlock *, const TVariable *> mLiftedMap;
27 std::unordered_map<const TVariable *, const TVariable *> mInstanceMap;
28 IdGen &mIdGen;
29
30 public:
Reducer(TCompiler & compiler,IdGen & idGen)31 Reducer(TCompiler &compiler, IdGen &idGen)
32 : TIntermRebuild(compiler, true, false), mIdGen(idGen)
33 {}
34
visitDeclarationPre(TIntermDeclaration & declNode)35 PreResult visitDeclarationPre(TIntermDeclaration &declNode) override
36 {
37 ASSERT(declNode.getChildCount() == 1);
38 TIntermNode &node = *declNode.getChildNode(0);
39
40 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
41 {
42 const TVariable &var = symbolNode->variable();
43 const TType &type = var.getType();
44 const SymbolType symbolType = var.symbolType();
45 if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
46 {
47 if (symbolType == SymbolType::Empty)
48 {
49 // Create instance variable
50 auto &structure =
51 *new TStructure(&mSymbolTable, interfaceBlock->name(),
52 &interfaceBlock->fields(), interfaceBlock->symbolType());
53 auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
54
55 auto &instanceVar = CreateInstanceVariable(
56 mSymbolTable, structure, mIdGen.createNewName(interfaceBlock->name()),
57 TQualifier::EvqBuffer, &type.getArraySizes());
58 mLiftedMap[interfaceBlock] = &instanceVar;
59
60 TIntermNode *replacements[] = {
61 new TIntermDeclaration{new TIntermSymbol(&structVar)},
62 new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
63 return PreResult::Multi(std::begin(replacements), std::end(replacements));
64 }
65 else
66 {
67 ASSERT(type.getQualifier() == TQualifier::EvqUniform);
68
69 auto &structure =
70 *new TStructure(&mSymbolTable, interfaceBlock->name(),
71 &interfaceBlock->fields(), interfaceBlock->symbolType());
72 auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
73 auto &instanceVar =
74 CreateInstanceVariable(mSymbolTable, structure, Name(var),
75 TQualifier::EvqBuffer, &type.getArraySizes());
76
77 mInstanceMap[&var] = &instanceVar;
78
79 TIntermNode *replacements[] = {
80 new TIntermDeclaration{new TIntermSymbol(&structVar)},
81 new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
82 return PreResult::Multi(std::begin(replacements), std::end(replacements));
83 }
84 }
85 }
86
87 return {declNode, VisitBits::Both};
88 }
89
visitSymbolPre(TIntermSymbol & symbolNode)90 PreResult visitSymbolPre(TIntermSymbol &symbolNode) override
91 {
92 const TVariable &var = symbolNode.variable();
93 {
94 auto it = mInstanceMap.find(&var);
95 if (it != mInstanceMap.end())
96 {
97 return *new TIntermSymbol(it->second);
98 }
99 }
100 if (const TInterfaceBlock *ib = var.getType().getInterfaceBlock())
101 {
102 auto it = mLiftedMap.find(ib);
103 if (it != mLiftedMap.end())
104 {
105 return AccessField(*(it->second), Name(var));
106 }
107 }
108 return symbolNode;
109 }
110 };
111
112 } // anonymous namespace
113
114 ////////////////////////////////////////////////////////////////////////////////
115
ReduceInterfaceBlocks(TCompiler & compiler,TIntermBlock & root,IdGen & idGen)116 bool sh::ReduceInterfaceBlocks(TCompiler &compiler, TIntermBlock &root, IdGen &idGen)
117 {
118 Reducer reducer(compiler, idGen);
119 if (!reducer.rebuildRoot(root))
120 {
121 return false;
122 }
123
124 if (!SeparateDeclarations(compiler, root, false))
125 {
126 return false;
127 }
128
129 return true;
130 }
131