1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "function_utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/ir_context.h"
23 
24 namespace spvtools {
25 namespace opt {
26 namespace {
27 
28 using ::testing::Eq;
29 
// Checks Function::HasEarlyReturn() on two functions:
//  - %6 only returns from its last block, so it has no early return.
//  - %11 returns from %13, inside a selection construct, before reaching the
//    merge block %15, so it has an early return.
TEST(FunctionTest, HasEarlyReturn) {
  const std::string shader = R"(
          OpCapability Shader
     %1 = OpExtInstImport "GLSL.std.450"
          OpMemoryModel Logical GLSL450
          OpEntryPoint Vertex %6 "main"

; Types
     %2 = OpTypeBool
     %3 = OpTypeVoid
     %4 = OpTypeFunction %3

; Constants
     %5 = OpConstantTrue %2

; main function without early return
     %6 = OpFunction %3 None %4
     %7 = OpLabel
          OpBranch %8
     %8 = OpLabel
          OpBranch %9
     %9 = OpLabel
          OpBranch %10
    %10 = OpLabel
          OpReturn
          OpFunctionEnd

; function with early return
    %11 = OpFunction %3 None %4
    %12 = OpLabel
          OpSelectionMerge %15 None
          OpBranchConditional %5 %13 %14
    %13 = OpLabel
          OpReturn
    %14 = OpLabel
          OpBranch %15
    %15 = OpLabel
          OpReturn
          OpFunctionEnd
  )";

  const auto context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, shader,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(context);

  // Tests |function| without early return.
  auto* function = spvtest::GetFunction(context->module(), 6);
  ASSERT_NE(nullptr, function);
  ASSERT_FALSE(function->HasEarlyReturn());

  // Tests |function| with early return.
  function = spvtest::GetFunction(context->module(), 11);
  ASSERT_NE(nullptr, function);
  ASSERT_TRUE(function->HasEarlyReturn());
}
83 
// Function::IsRecursive() must be false for functions that are not on a call
// cycle. Call graph: %1 -> %9 -> %12, and %12 makes no calls.
TEST(FunctionTest, IsNotRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);

  // %9 calls %12, which never calls back into %9.
  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_NE(nullptr, func);
  EXPECT_FALSE(func->IsRecursive());

  // %12 makes no calls at all.
  func = spvtest::GetFunction(ctx->module(), 12);
  ASSERT_NE(nullptr, func);
  EXPECT_FALSE(func->IsRecursive());
}
121 
// Function::IsRecursive() must be true for a function that calls itself:
// %9 contains an OpFunctionCall whose callee is %9.
TEST(FunctionTest, IsDirectlyRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);

  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_NE(nullptr, func);
  EXPECT_TRUE(func->IsRecursive());
}
152 
// Function::IsRecursive() must be true for both functions on an indirect call
// cycle: %9 calls %12 and %12 calls %9.
TEST(FunctionTest, IsIndirectlyRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
%14 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);

  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_NE(nullptr, func);
  EXPECT_TRUE(func->IsRecursive());

  func = spvtest::GetFunction(ctx->module(), 12);
  ASSERT_NE(nullptr, func);
  EXPECT_TRUE(func->IsRecursive());
}
191 
// A function that calls a recursive function is not itself recursive unless
// it lies on the call cycle. Here %1 calls %9, and %9 calls itself, but
// nothing calls back into %1.
TEST(FunctionTest, IsNotRecuriseCallingRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);

  // Precondition: the callee %9 really is recursive; otherwise the check on
  // %1 below would pass trivially and test nothing.
  auto* callee = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_NE(nullptr, callee);
  ASSERT_TRUE(callee->IsRecursive());

  auto* func = spvtest::GetFunction(ctx->module(), 1);
  ASSERT_NE(nullptr, func);
  EXPECT_FALSE(func->IsRecursive());
}
222 
// With non-semantic iteration disabled, Function::ForEachInst() must visit the
// OpExtInst inside the function body (%6) but not the non-semantic OpExtInst
// instructions that follow OpFunctionEnd (%7 and %8).
TEST(FunctionTest, NonSemanticInfoSkipIteration) {
  const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_non_semantic_info"
%1 = OpExtInstImport "NonSemantic.Test"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpFunction %2 None %3
%5 = OpLabel
%6 = OpExtInst %2 %1 1
OpReturn
OpFunctionEnd
%7 = OpExtInst %2 %1 2
%8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_TRUE(func != nullptr);

  // Collect the result ids of every OpExtInst the iteration visits.
  std::unordered_set<uint32_t> non_semantic_ids;
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      /* run_on_debug_line_insts = */ true,
      /* run_on_non_semantic_insts = */ false);

  // Exactly %6. An exact-set match also catches unexpected extra visits, and
  // it avoids comparing int literals against size_t counts.
  EXPECT_THAT(non_semantic_ids, ::testing::UnorderedElementsAre(6u));
}
259 
// With non-semantic iteration enabled, Function::ForEachInst() must visit both
// the OpExtInst inside the function body (%6) and the non-semantic OpExtInst
// instructions that follow OpFunctionEnd (%7 and %8).
TEST(FunctionTest, NonSemanticInfoIncludeIteration) {
  const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_non_semantic_info"
%1 = OpExtInstImport "NonSemantic.Test"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpFunction %2 None %3
%5 = OpLabel
%6 = OpExtInst %2 %1 1
OpReturn
OpFunctionEnd
%7 = OpExtInst %2 %1 2
%8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Report a test failure instead of dereferencing a null context.
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_TRUE(func != nullptr);

  // Collect the result ids of every OpExtInst the iteration visits.
  std::unordered_set<uint32_t> non_semantic_ids;
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      /* run_on_debug_line_insts = */ true,
      /* run_on_non_semantic_insts = */ true);

  // Exactly %6, %7 and %8. These are the only OpExtInst results in the
  // module, so an exact-set match also catches duplicates or stray visits.
  EXPECT_THAT(non_semantic_ids, ::testing::UnorderedElementsAre(6u, 7u, 8u));
}
296 
// Function::ReorderBasicBlocksInStructuredOrder() must rearrange a function's
// blocks into structured order. The blocks below are listed out of order.
// Their ids are chosen so that the correct result is the entry block %11
// followed by %1, %2, ..., %10.
TEST(FunctionTest, ReorderBlocksinStructuredOrder) {
  // The spir-v has the basic blocks in a random order.  We want to reorder
  // them in structured order.
  // OpExecutionMode targets %100, the entry point. The previous %PSMain was
  // never defined anywhere in the module.
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %100 "PSMain"
               OpExecutionMode %100 OriginUpperLeft
               OpSource HLSL 600
        %int = OpTypeInt 32 1
       %void = OpTypeVoid
         %19 = OpTypeFunction %void
       %bool = OpTypeBool
%undef_bool = OpUndef %bool
%undef_int = OpUndef %int
        %100 = OpFunction %void None %19
          %11 = OpLabel
               OpSelectionMerge %10 None
               OpSwitch %undef_int %3 0 %2 10 %1
          %2 = OpLabel
               OpReturn
          %7 = OpLabel
               OpBranch %8
          %3 = OpLabel
               OpBranch %4
         %10 = OpLabel
               OpReturn
          %9 = OpLabel
               OpBranch %10
          %8 = OpLabel
               OpBranch %4
          %4 = OpLabel
               OpLoopMerge %9 %8 None
               OpBranchConditional %undef_bool %5 %9
          %1 = OpLabel
               OpReturn
          %6 = OpLabel
               OpBranch %7
          %5 = OpLabel
               OpSelectionMerge %7 None
               OpBranchConditional %undef_bool %6 %7
               OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 100);
  ASSERT_TRUE(func);
  func->ReorderBasicBlocksInStructuredOrder();

  // Compare the whole block sequence, including the entry block and the block
  // count. Comparing two vectors of uint32_t also avoids mixing unsigned ids
  // with signed iterator differences.
  std::vector<uint32_t> actual_order;
  for (const auto& block : *func) {
    actual_order.push_back(block.id());
  }
  const std::vector<uint32_t> expected_order = {11, 1, 2, 3, 4, 5,
                                                6,  7, 8, 9, 10};
  EXPECT_EQ(expected_order, actual_order);
}
355 
356 }  // namespace
357 }  // namespace opt
358 }  // namespace spvtools
359