1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <memory>
16 #include <vector>
17
18 #include "function_utils.h"
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include "source/opt/build_module.h"
22 #include "source/opt/ir_context.h"
23
24 namespace spvtools {
25 namespace opt {
26 namespace {
27
28 using ::testing::Eq;
29
// Checks Function::HasEarlyReturn() on two functions of one module:
// %6 is a straight chain of branches ending in a single OpReturn, and %11
// returns from inside a selection construct as well as from its merge block.
TEST(FunctionTest, HasEarlyReturn) {
  const std::string shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %6 "main"

; Types
%2 = OpTypeBool
%3 = OpTypeVoid
%4 = OpTypeFunction %3

; Constants
%5 = OpConstantTrue %2

; main function without early return
%6 = OpFunction %3 None %4
%7 = OpLabel
OpBranch %8
%8 = OpLabel
OpBranch %9
%9 = OpLabel
OpBranch %10
%10 = OpLabel
OpReturn
OpFunctionEnd

; function with early return
%11 = OpFunction %3 None %4
%12 = OpLabel
OpSelectionMerge %15 None
OpBranchConditional %5 %13 %14
%13 = OpLabel
OpReturn
%14 = OpLabel
OpBranch %15
%15 = OpLabel
OpReturn
OpFunctionEnd
)";

  const auto context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, shader,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Stop here if the module could not be built; dereferencing a null context
  // below would crash the test binary instead of reporting a failure.
  ASSERT_TRUE(context != nullptr);

  // Function %6 has no early return.
  auto* function = spvtest::GetFunction(context->module(), 6);
  ASSERT_TRUE(function != nullptr);
  ASSERT_FALSE(function->HasEarlyReturn());

  // Function %11 has an early return.
  function = spvtest::GetFunction(context->module(), 11);
  ASSERT_TRUE(function != nullptr);
  ASSERT_TRUE(function->HasEarlyReturn());
}
83
// Checks that Function::IsRecursive() is false for functions on an acyclic
// call chain: %1 calls %9, %9 calls %12, and %12 calls nothing.
TEST(FunctionTest, IsNotRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fail cleanly rather than dereferencing a null context or function.
  ASSERT_TRUE(ctx != nullptr);

  // %9 only calls %12.
  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_TRUE(func != nullptr);
  EXPECT_FALSE(func->IsRecursive());

  // %12 makes no calls at all.
  func = spvtest::GetFunction(ctx->module(), 12);
  ASSERT_TRUE(func != nullptr);
  EXPECT_FALSE(func->IsRecursive());
}
121
// Checks that Function::IsRecursive() is true for a function whose body
// contains a call to itself (%9 calls %9).
TEST(FunctionTest, IsDirectlyRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fail cleanly rather than dereferencing a null context or function.
  ASSERT_TRUE(ctx != nullptr);

  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_TRUE(func != nullptr);
  EXPECT_TRUE(func->IsRecursive());
}
152
// Checks that Function::IsRecursive() is true for both members of a two
// function call cycle: %9 calls %12 and %12 calls %9.
TEST(FunctionTest, IsIndirectlyRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
%14 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fail cleanly rather than dereferencing a null context or function.
  ASSERT_TRUE(ctx != nullptr);

  // %9 reaches itself through %12.
  auto* func = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_TRUE(func != nullptr);
  EXPECT_TRUE(func->IsRecursive());

  // %12 reaches itself through %9.
  func = spvtest::GetFunction(ctx->module(), 12);
  ASSERT_TRUE(func != nullptr);
  EXPECT_TRUE(func->IsRecursive());
}
191
// Checks that a function is not reported as recursive merely because it
// calls a recursive function: %1 calls %9, and only %9 calls itself.
//
// NOTE: "Recurise" in the test name is a misspelling of "Recursive". The name
// is left unchanged so existing --gtest_filter patterns keep matching.
TEST(FunctionTest, IsNotRecuriseCallingRecursive) {
  const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %2 DescriptorSet 439418829
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fail cleanly rather than dereferencing a null context or function.
  ASSERT_TRUE(ctx != nullptr);

  auto* func = spvtest::GetFunction(ctx->module(), 1);
  ASSERT_TRUE(func != nullptr);
  EXPECT_FALSE(func->IsRecursive());
}
222
// Iterates function %4 with ForEachInst(callback, true, false) and records
// the result id of every OpExtInst visited. Only %6, which sits inside the
// function body, is expected; %7 and %8, which follow OpFunctionEnd, are not.
// NOTE(review): the last argument presumably toggles visiting the trailing
// non-semantic instructions -- confirm against Function::ForEachInst.
TEST(FunctionTest, NonSemanticInfoSkipIteration) {
  const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_non_semantic_info"
%1 = OpExtInstImport "NonSemantic.Test"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpFunction %2 None %3
%5 = OpLabel
%6 = OpExtInst %2 %1 1
OpReturn
OpFunctionEnd
%7 = OpExtInst %2 %1 2
%8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Guard the context before it is dereferenced for the function lookup.
  ASSERT_TRUE(ctx != nullptr);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_TRUE(func != nullptr);

  std::unordered_set<uint32_t> non_semantic_ids;
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      true, false);

  // count() returns an unsigned size, so compare against unsigned literals.
  EXPECT_EQ(1u, non_semantic_ids.count(6));
  EXPECT_EQ(0u, non_semantic_ids.count(7));
  EXPECT_EQ(0u, non_semantic_ids.count(8));
}
259
// Iterates function %4 with ForEachInst(callback, true, true) and records the
// result id of every OpExtInst visited. All three are expected: %6 inside the
// function body, and %7 and %8, which follow OpFunctionEnd.
// NOTE(review): the last argument presumably toggles visiting the trailing
// non-semantic instructions -- confirm against Function::ForEachInst.
TEST(FunctionTest, NonSemanticInfoIncludeIteration) {
  const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_non_semantic_info"
%1 = OpExtInstImport "NonSemantic.Test"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpFunction %2 None %3
%5 = OpLabel
%6 = OpExtInst %2 %1 1
OpReturn
OpFunctionEnd
%7 = OpExtInst %2 %1 2
%8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Guard the context before it is dereferenced for the function lookup.
  ASSERT_TRUE(ctx != nullptr);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_TRUE(func != nullptr);

  std::unordered_set<uint32_t> non_semantic_ids;
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      true, true);

  // count() returns an unsigned size, so compare against unsigned literals.
  EXPECT_EQ(1u, non_semantic_ids.count(6));
  EXPECT_EQ(1u, non_semantic_ids.count(7));
  EXPECT_EQ(1u, non_semantic_ids.count(8));
}
296
// The basic blocks of function %100 are listed in a scrambled order. After
// ReorderBasicBlocksInStructuredOrder() runs, every block other than the
// first is expected to have an id equal to its position in the function.
TEST(FunctionTest, ReorderBlocksinStructuredOrder) {
  const std::string assembly = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %100 "PSMain"
OpExecutionMode %PSMain OriginUpperLeft
OpSource HLSL 600
%int = OpTypeInt 32 1
%void = OpTypeVoid
%19 = OpTypeFunction %void
%bool = OpTypeBool
%undef_bool = OpUndef %bool
%undef_int = OpUndef %int
%100 = OpFunction %void None %19
%11 = OpLabel
OpSelectionMerge %10 None
OpSwitch %undef_int %3 0 %2 10 %1
%2 = OpLabel
OpReturn
%7 = OpLabel
OpBranch %8
%3 = OpLabel
OpBranch %4
%10 = OpLabel
OpReturn
%9 = OpLabel
OpBranch %10
%8 = OpLabel
OpBranch %4
%4 = OpLabel
OpLoopMerge %9 %8 None
OpBranchConditional %undef_bool %5 %9
%1 = OpLabel
OpReturn
%6 = OpLabel
OpBranch %7
%5 = OpLabel
OpSelectionMerge %7 None
OpBranchConditional %undef_bool %6 %7
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context = spvtools::BuildModule(
      SPV_ENV_UNIVERSAL_1_1, nullptr, assembly,
      SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(context);

  auto* fn = spvtest::GetFunction(context->module(), 100);
  ASSERT_TRUE(fn);
  fn->ReorderBasicBlocksInStructuredOrder();

  // Walk from the second block onwards; the first block is not checked.
  auto entry = fn->begin();
  auto cursor = entry;
  ++cursor;
  while (cursor != fn->end()) {
    EXPECT_EQ(cursor->id(), (cursor - entry));
    ++cursor;
  }
}
355
356 } // namespace
357 } // namespace opt
358 } // namespace spvtools
359