1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/propagator.h"
16
17 #include <map>
18 #include <memory>
19 #include <vector>
20
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/build_module.h"
24 #include "source/opt/cfg.h"
25 #include "source/opt/ir_context.h"
26
27 namespace spvtools {
28 namespace opt {
29 namespace {
30
31 using ::testing::UnorderedElementsAre;
32
33 class PropagatorTest : public testing::Test {
34 protected:
TearDown()35 virtual void TearDown() {
36 ctx_.reset(nullptr);
37 values_.clear();
38 values_vec_.clear();
39 }
40
Assemble(const std::string & input)41 void Assemble(const std::string& input) {
42 ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
43 ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
44 << input << "\n";
45 }
46
Propagate(const SSAPropagator::VisitFunction & visit_fn)47 bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
48 SSAPropagator propagator(ctx_.get(), visit_fn);
49 bool retval = false;
50 for (auto& fn : *ctx_->module()) {
51 retval |= propagator.Run(&fn);
52 }
53 return retval;
54 }
55
GetValues()56 const std::vector<uint32_t>& GetValues() {
57 values_vec_.clear();
58 for (const auto& it : values_) {
59 values_vec_.push_back(it.second);
60 }
61 return values_vec_;
62 }
63
64 std::unique_ptr<IRContext> ctx_;
65 std::map<uint32_t, uint32_t> values_;
66 std::vector<uint32_t> values_vec_;
67 };
68
// Each store of a constant into a function-local variable is reported as
// interesting, and the visit function records every stored constant.
TEST_F(PropagatorTest, LocalPropagate) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %y "y"
               OpName %z "z"
               OpName %outparm "outparm"
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %5 = OpLabel
          %x = OpVariable %_ptr_Function_int Function
          %y = OpVariable %_ptr_Function_int Function
          %z = OpVariable %_ptr_Function_int Function
               OpStore %x %int_4
               OpStore %y %int_3
               OpStore %z %int_1
         %20 = OpLoad %int %z
               OpStore %outparm %20
               OpReturn
               OpFunctionEnd
)";
  // Assemble() signals failure with a fatal assertion that only returns from
  // the helper; stop here instead of dereferencing a null ctx_ below.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  // "OpStore %ptr %const" is interesting: record the constant's literal under
  // the pointer id. Every other instruction is reported as varying.
  const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
    *dest_bb = nullptr;
    if (instr->opcode() == spv::Op::OpStore) {
      // OpStore operands: 0 = pointer, 1 = object being stored.
      const uint32_t lhs_id = instr->GetSingleWordOperand(0);
      const uint32_t rhs_id = instr->GetSingleWordOperand(1);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def != nullptr && rhs_def->opcode() == spv::Op::OpConstant) {
        // Operand 2 of OpConstant is its first literal word.
        values_[lhs_id] = rhs_def->GetSingleWordOperand(2);
        return SSAPropagator::kInteresting;
      }
    }
    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 3u, 1u));
}
125
// Both arms of a selection load the same constant (4); the Phi merging them
// should be propagated as the constant 4 as well.
TEST_F(PropagatorTest, PropagateThroughPhis) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %x %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %outparm "outparm"
               OpDecorate %x Flat
               OpDecorate %x Location 0
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
       %bool = OpTypeBool
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Input_int = OpTypePointer Input %int
          %x = OpVariable %_ptr_Input_int Input
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %4 = OpLabel
          %5 = OpLoad %int %x
          %6 = OpSGreaterThan %bool %5 %int_3
               OpSelectionMerge %25 None
               OpBranchConditional %6 %22 %23
         %22 = OpLabel
          %7 = OpLoad %int %int_4
               OpBranch %25
         %23 = OpLabel
          %8 = OpLoad %int %int_4
               OpBranch %25
         %25 = OpLabel
         %35 = OpPhi %int %7 %22 %8 %23
               OpStore %outparm %35
               OpReturn
               OpFunctionEnd
)";

  // Assemble() signals failure with a fatal assertion that only returns from
  // the helper; stop here instead of dereferencing a null ctx_ below.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  Instruction* phi_instr = nullptr;
  const auto visit_fn = [this, &phi_instr](Instruction* instr,
                                           BasicBlock** dest_bb) {
    *dest_bb = nullptr;
    if (instr->opcode() == spv::Op::OpLoad) {
      // Operand 2 of OpLoad is the pointer; this module deliberately loads
      // straight from constants so the loaded value is known.
      const uint32_t rhs_id = instr->GetSingleWordOperand(2);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def != nullptr && rhs_def->opcode() == spv::Op::OpConstant) {
        values_[instr->result_id()] = rhs_def->GetSingleWordOperand(2);
        return SSAPropagator::kInteresting;
      }
    } else if (instr->opcode() == spv::Op::OpPhi) {
      phi_instr = instr;
      // Initialized so a Phi with no incoming pairs never returns an
      // indeterminate status.
      SSAPropagator::PropStatus retval = SSAPropagator::kNotInteresting;
      // From operand 2 onward, Phi operands are (value id, parent block id)
      // pairs; only the value ids are inspected.
      for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
        const uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
        const auto it = values_.find(phi_arg_id);
        if (it == values_.end()) {
          retval = SSAPropagator::kNotInteresting;
          break;
        }
        EXPECT_EQ(it->second, 4u);
        retval = SSAPropagator::kInteresting;
        values_[instr->result_id()] = it->second;
      }
      return retval;
    }

    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));

  // The propagator should've concluded that the Phi instruction has a constant
  // value of 4. Use fatal checks and a non-inserting lookup so a missing Phi or
  // missing value fails cleanly instead of crashing or adding a spurious 0.
  ASSERT_NE(phi_instr, nullptr);
  ASSERT_EQ(values_.count(phi_instr->result_id()), 1u);
  EXPECT_EQ(values_.at(phi_instr->result_id()), 4u);

  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
}
215
216 } // namespace
217 } // namespace opt
218 } // namespace spvtools
219