1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/propagator.h"
16 
17 #include <map>
18 #include <memory>
19 #include <vector>
20 
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/build_module.h"
24 #include "source/opt/cfg.h"
25 #include "source/opt/ir_context.h"
26 
27 namespace spvtools {
28 namespace opt {
29 namespace {
30 
31 using ::testing::UnorderedElementsAre;
32 
33 class PropagatorTest : public testing::Test {
34  protected:
TearDown()35   virtual void TearDown() {
36     ctx_.reset(nullptr);
37     values_.clear();
38     values_vec_.clear();
39   }
40 
Assemble(const std::string & input)41   void Assemble(const std::string& input) {
42     ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
43     ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
44                              << input << "\n";
45   }
46 
Propagate(const SSAPropagator::VisitFunction & visit_fn)47   bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
48     SSAPropagator propagator(ctx_.get(), visit_fn);
49     bool retval = false;
50     for (auto& fn : *ctx_->module()) {
51       retval |= propagator.Run(&fn);
52     }
53     return retval;
54   }
55 
GetValues()56   const std::vector<uint32_t>& GetValues() {
57     values_vec_.clear();
58     for (const auto& it : values_) {
59       values_vec_.push_back(it.second);
60     }
61     return values_vec_;
62   }
63 
64   std::unique_ptr<IRContext> ctx_;
65   std::map<uint32_t, uint32_t> values_;
66   std::vector<uint32_t> values_vec_;
67 };
68 
TEST_F(PropagatorTest, LocalPropagate) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %y "y"
               OpName %z "z"
               OpName %outparm "outparm"
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %5 = OpLabel
          %x = OpVariable %_ptr_Function_int Function
          %y = OpVariable %_ptr_Function_int Function
          %z = OpVariable %_ptr_Function_int Function
               OpStore %x %int_4
               OpStore %y %int_3
               OpStore %z %int_1
         %20 = OpLoad %int %z
               OpStore %outparm %20
               OpReturn
               OpFunctionEnd
               )";
  // Assemble() reports failure with an ASSERT inside a helper, which does not
  // end this test on its own; stop here rather than propagate over nothing.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  // Records the literal stored by each "OpStore %ptr %constant", keyed by the
  // pointer id. Every other instruction is reported as varying.
  const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
    // This visitor never selects a specific successor block.
    *dest_bb = nullptr;
    if (instr->opcode() == spv::Op::OpStore) {
      // OpStore has no result type or id: operand 0 is the pointer and
      // operand 1 is the object being stored.
      uint32_t lhs_id = instr->GetSingleWordOperand(0);
      uint32_t rhs_id = instr->GetSingleWordOperand(1);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def != nullptr && rhs_def->opcode() == spv::Op::OpConstant) {
        // OpConstant operands are (result type, result id, literal...).
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[lhs_id] = val;
        return SSAPropagator::kInteresting;
      }
    }
    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));
  // One recorded value per store of a constant into %x, %y and %z.
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1));
}
125 
TEST_F(PropagatorTest, PropagateThroughPhis) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %x %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %outparm "outparm"
               OpDecorate %x Flat
               OpDecorate %x Location 0
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
       %bool = OpTypeBool
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Input_int = OpTypePointer Input %int
          %x = OpVariable %_ptr_Input_int Input
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %4 = OpLabel
          %5 = OpLoad %int %x
          %6 = OpSGreaterThan %bool %5 %int_3
               OpSelectionMerge %25 None
               OpBranchConditional %6 %22 %23
         %22 = OpLabel
          %7 = OpLoad %int %int_4
               OpBranch %25
         %23 = OpLabel
          %8 = OpLoad %int %int_4
               OpBranch %25
         %25 = OpLabel
         %35 = OpPhi %int %7 %22 %8 %23
               OpStore %outparm %35
               OpReturn
               OpFunctionEnd
               )";

  // Assemble() reports failure with an ASSERT inside a helper, which does not
  // end this test on its own; stop here rather than propagate over nothing.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  Instruction* phi_instr = nullptr;
  const auto visit_fn = [this, &phi_instr](Instruction* instr,
                                           BasicBlock** dest_bb) {
    // This visitor never selects a specific successor block.
    *dest_bb = nullptr;
    if (instr->opcode() == spv::Op::OpLoad) {
      // The shader "loads" straight from a constant id (%7 and %8); record
      // that constant's literal as the loaded value. Operand 2 is the
      // pointer operand, after the result type and result id.
      uint32_t rhs_id = instr->GetSingleWordOperand(2);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def != nullptr && rhs_def->opcode() == spv::Op::OpConstant) {
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[instr->result_id()] = val;
        return SSAPropagator::kInteresting;
      }
    } else if (instr->opcode() == spv::Op::OpPhi) {
      phi_instr = instr;
      // Initialized so a Phi with no (value, parent) pairs never returns an
      // indeterminate status.
      SSAPropagator::PropStatus retval = SSAPropagator::kNotInteresting;
      // After the result type and id, Phi operands come in (value, parent
      // block) pairs; step by 2 to visit only the values.
      for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
        uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
        auto it = values_.find(phi_arg_id);
        if (it != values_.end()) {
          EXPECT_EQ(it->second, 4u);
          retval = SSAPropagator::kInteresting;
          values_[instr->result_id()] = it->second;
        } else {
          // An incoming value is still unknown; report the Phi as not
          // interesting for now.
          retval = SSAPropagator::kNotInteresting;
          break;
        }
      }
      return retval;
    }

    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));

  // The propagator should've concluded that the Phi instruction has a constant
  // value of 4. Stop if the Phi was never visited, because the checks below
  // dereference |phi_instr|.
  ASSERT_NE(phi_instr, nullptr);
  // Use count()/at() rather than operator[], which would default-insert a 0
  // and corrupt the GetValues() check below.
  ASSERT_EQ(values_.count(phi_instr->result_id()), 1u);
  EXPECT_EQ(values_.at(phi_instr->result_id()), 4u);

  // %7, %8 and the Phi %35 all carry the value 4.
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
}
215 
216 }  // namespace
217 }  // namespace opt
218 }  // namespace spvtools
219