1 // Copyright (c) 2015-2016 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Assembler tests for instructions in the "Control Flow" section of the
16 // SPIR-V spec.
17 
18 #include <sstream>
19 #include <string>
20 #include <tuple>
21 #include <vector>
22 
23 #include "gmock/gmock.h"
24 #include "test/test_fixture.h"
25 #include "test/unit_spirv.h"
26 
27 namespace spvtools {
28 namespace {
29 
30 using spvtest::Concatenate;
31 using spvtest::EnumCase;
32 using spvtest::MakeInstruction;
33 using spvtest::TextToBinaryTest;
34 using ::testing::Combine;
35 using ::testing::Eq;
36 using ::testing::TestWithParam;
37 using ::testing::Values;
38 using ::testing::ValuesIn;
39 
40 // Test OpSelectionMerge
41 
42 using OpSelectionMergeTest = spvtest::TextToBinaryTestBase<
43     TestWithParam<EnumCase<spv::SelectionControlMask>>>;
44 
TEST_P(OpSelectionMergeTest,AnySingleSelectionControlMask)45 TEST_P(OpSelectionMergeTest, AnySingleSelectionControlMask) {
46   const std::string input = "OpSelectionMerge %1 " + GetParam().name();
47   EXPECT_THAT(CompiledInstructions(input),
48               Eq(MakeInstruction(spv::Op::OpSelectionMerge,
49                                  {1, uint32_t(GetParam().value())})));
50 }
51 
52 // clang-format off
53 #define CASE(VALUE,NAME) { spv::SelectionControlMask::VALUE, NAME}
54 INSTANTIATE_TEST_SUITE_P(TextToBinarySelectionMerge, OpSelectionMergeTest,
55                         ValuesIn(std::vector<EnumCase<spv::SelectionControlMask>>{
56                             CASE(MaskNone, "None"),
57                             CASE(Flatten, "Flatten"),
58                             CASE(DontFlatten, "DontFlatten"),
59                         }));
60 #undef CASE
61 // clang-format on
62 
// Several SelectionControl enumerants joined with '|' assemble to the
// bitwise-OR of their values.
TEST_F(OpSelectionMergeTest, CombinedSelectionControlMask) {
  const uint32_t combined = uint32_t(spv::SelectionControlMask::Flatten |
                                     spv::SelectionControlMask::DontFlatten);
  const auto expected =
      MakeInstruction(spv::Op::OpSelectionMerge, {1, combined});
  EXPECT_THAT(CompiledInstructions("OpSelectionMerge %1 Flatten|DontFlatten"),
              Eq(expected));
}
72 
// Enumerant names are case sensitive: "flatten" is not "Flatten", so the
// whole mask operand is reported as invalid.
TEST_F(OpSelectionMergeTest, WrongSelectionControl) {
  const std::string bad_input = "OpSelectionMerge %1 flatten|DontFlatten";
  EXPECT_THAT(CompileFailure(bad_input),
              Eq("Invalid selection control operand 'flatten|DontFlatten'."));
}
78 
79 // Test OpLoopMerge
80 
81 using OpLoopMergeTest = spvtest::TextToBinaryTestBase<
82     TestWithParam<std::tuple<spv_target_env, EnumCase<int>>>>;
83 
TEST_P(OpLoopMergeTest,AnySingleLoopControlMask)84 TEST_P(OpLoopMergeTest, AnySingleLoopControlMask) {
85   const auto ctrl = std::get<1>(GetParam());
86   std::ostringstream input;
87   input << "OpLoopMerge %merge %continue " << ctrl.name();
88   for (auto num : ctrl.operands()) input << " " << num;
89   EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())),
90               Eq(MakeInstruction(spv::Op::OpLoopMerge, {1, 2, ctrl.value()},
91                                  ctrl.operands())));
92 }
93 
94 #define CASE(VALUE, NAME) \
95   { int32_t(spv::LoopControlMask::VALUE), NAME }
96 #define CASE1(VALUE, NAME, PARM)                         \
97   {                                                      \
98     int32_t(spv::LoopControlMask::VALUE), NAME, { PARM } \
99   }
100 INSTANTIATE_TEST_SUITE_P(
101     TextToBinaryLoopMerge, OpLoopMergeTest,
102     Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
103             ValuesIn(std::vector<EnumCase<int>>{
104                 // clang-format off
105                 CASE(MaskNone, "None"),
106                 CASE(Unroll, "Unroll"),
107                 CASE(DontUnroll, "DontUnroll"),
108                 // clang-format on
109             })));
110 
111 INSTANTIATE_TEST_SUITE_P(
112     TextToBinaryLoopMergeV11, OpLoopMergeTest,
113     Combine(Values(SPV_ENV_UNIVERSAL_1_1),
114             ValuesIn(std::vector<EnumCase<int>>{
115                 // clang-format off
116                 CASE(DependencyInfinite, "DependencyInfinite"),
117                 CASE1(DependencyLength, "DependencyLength", 234),
118                 {int32_t(spv::LoopControlMask::Unroll|spv::LoopControlMask::DependencyLength),
119                       "DependencyLength|Unroll", {33}},
120                 // clang-format on
121             })));
122 #undef CASE
123 #undef CASE1
124 
// Several LoopControl enumerants joined with '|' assemble to the bitwise-OR
// of their values.
TEST_F(OpLoopMergeTest, CombinedLoopControlMask) {
  const uint32_t combined =
      uint32_t(spv::LoopControlMask::Unroll | spv::LoopControlMask::DontUnroll);
  const auto expected =
      MakeInstruction(spv::Op::OpLoopMerge, {1, 2, combined});
  EXPECT_THAT(
      CompiledInstructions("OpLoopMerge %merge %continue Unroll|DontUnroll"),
      Eq(expected));
}
132 
// Loop control names are case sensitive: lowercase "none" is rejected.
TEST_F(OpLoopMergeTest, WrongLoopControl) {
  const std::string bad_input = "OpLoopMerge %m %c none";
  const std::string expected_error = "Invalid loop control operand 'none'.";
  EXPECT_THAT(CompileFailure(bad_input), Eq(expected_error));
}
137 
138 // Test OpSwitch
139 
// A selector plus a default target, with no (literal, label) pairs, is a
// complete OpSwitch.
TEST_F(TextToBinaryTest, SwitchGoodZeroTargets) {
  const auto expected = MakeInstruction(spv::Op::OpSwitch, {1, 2});
  EXPECT_THAT(CompiledInstructions("OpSwitch %selector %default"),
              Eq(expected));
}
144 
// One (literal, label) pair follows the default target. %default and
// %target0 are forward references that are encoded as ids 3 and 4.
TEST_F(TextToBinaryTest, SwitchGoodOneTarget) {
  const std::string assembly =
      "%1 = OpTypeInt 32 0\n"
      "%2 = OpConstant %1 52\n"
      "OpSwitch %2 %default 12 %target0";
  const auto expected =
      Concatenate({MakeInstruction(spv::Op::OpTypeInt, {1, 32, 0}),
                   MakeInstruction(spv::Op::OpConstant, {1, 2, 52}),
                   MakeInstruction(spv::Op::OpSwitch, {2, 3, 12, 4})});
  EXPECT_THAT(CompiledInstructions(assembly), Eq(expected));
}
154 
// Two (literal, label) pairs follow the default target. The forward
// references %default, %target0 and %target1 are encoded as ids 3, 4 and 5.
TEST_F(TextToBinaryTest, SwitchGoodTwoTargets) {
  const std::string assembly =
      "%1 = OpTypeInt 32 0\n"
      "%2 = OpConstant %1 52\n"
      "OpSwitch %2 %default 12 %target0 42 %target1";
  const auto expected = Concatenate({
      MakeInstruction(spv::Op::OpTypeInt, {1, 32, 0}),
      MakeInstruction(spv::Op::OpConstant, {1, 2, 52}),
      MakeInstruction(spv::Op::OpSwitch, {2, 3, 12, 4, 42, 5}),
  });
  EXPECT_THAT(CompiledInstructions(assembly), Eq(expected));
}
166 
// A bare "OpSwitch" is missing its selector operand.
TEST_F(TextToBinaryTest, SwitchBadMissingSelector) {
  const std::string expected_error =
      "Expected operand for OpSwitch instruction, but found the end of the "
      "stream.";
  EXPECT_THAT(CompileFailure("OpSwitch"), Eq(expected_error));
}
172 
// The selector must be an id, so a bare literal in that position is rejected.
TEST_F(TextToBinaryTest, SwitchBadInvalidSelector) {
  const std::string expected_error = "Expected id to start with %.";
  EXPECT_THAT(CompileFailure("OpSwitch 12"), Eq(expected_error));
}
177 
// A selector alone is not enough: the default target is required.
TEST_F(TextToBinaryTest, SwitchBadMissingDefault) {
  const std::string expected_error =
      "Expected operand for OpSwitch instruction, but found the end of the "
      "stream.";
  EXPECT_THAT(CompileFailure("OpSwitch %selector"), Eq(expected_error));
}
183 
// The default target must be an id, not a literal.
TEST_F(TextToBinaryTest, SwitchBadInvalidDefault) {
  const std::string expected_error = "Expected id to start with %.";
  EXPECT_THAT(CompileFailure("OpSwitch %selector 12"), Eq(expected_error));
}
188 
// An id where a case literal is expected ends the OpSwitch early.
TEST_F(TextToBinaryTest, SwitchBadInvalidLiteral) {
  // "OpSwitch %selector %default" is accepted as a complete instruction.
  // "%abc" is then read as the result id of a new instruction, and the
  // stream ends before the expected '='.
  const char* const input = R"(%i32 = OpTypeInt 32 0
                        %selector = OpConstant %i32 42
                        OpSwitch %selector %default %abc)";
  const std::string expected_error = "Expected '=', found end of stream.";
  EXPECT_THAT(CompileFailure(input), Eq(expected_error));
}
198 
// A case literal without its target label is incomplete.
TEST_F(TextToBinaryTest, SwitchBadMissingTarget) {
  const std::string input =
      "%1 = OpTypeInt 32 0\n"
      "%2 = OpConstant %1 52\n"
      "OpSwitch %2 %default 12";
  const std::string expected_error =
      "Expected operand for OpSwitch instruction, but found the end of the "
      "stream.";
  EXPECT_THAT(CompileFailure(input), Eq(expected_error));
}
206 
// A test case for an OpSwitch.
// It is also parameterized to test the encoding of OpConstant
// integer literals.  This can capture both single and multi-word
// integer literal tests.
// NOTE: member order matters; instances are built by aggregate
// initialization in the same order as declared here.
struct SwitchTestCase {
  // Operands for OpTypeInt as assembly text: "<width> <signedness>".
  std::string constant_type_args;
  // OpConstant value (the switch selector) as written in assembly.
  std::string constant_value_arg;
  // The single OpSwitch case literal as written in assembly.
  std::string case_value_arg;
  // Expected binary words for the whole assembled input.
  std::vector<uint32_t> expected_instructions;
};
217 
218 using OpSwitchValidTest =
219     spvtest::TextToBinaryTestBase<TestWithParam<SwitchTestCase>>;
220 
221 // Tests the encoding of OpConstant literal values, and also
222 // the literal integer cases in an OpSwitch.  This can
223 // test both single and multi-word integer literal encodings.
TEST_P(OpSwitchValidTest,ValidTypes)224 TEST_P(OpSwitchValidTest, ValidTypes) {
225   const std::string input = "%1 = OpTypeInt " + GetParam().constant_type_args +
226                             "\n"
227                             "%2 = OpConstant %1 " +
228                             GetParam().constant_value_arg +
229                             "\n"
230                             "OpSwitch %2 %default " +
231                             GetParam().case_value_arg + " %4\n";
232   std::vector<uint32_t> instructions;
233   EXPECT_THAT(CompiledInstructions(input),
234               Eq(GetParam().expected_instructions));
235 }
236 
237 // Constructs a SwitchTestCase from the given integer_width, signedness,
238 // constant value string, and expected encoded constant.
MakeSwitchTestCase(uint32_t integer_width,uint32_t integer_signedness,std::string constant_str,std::vector<uint32_t> encoded_constant,std::string case_value_str,std::vector<uint32_t> encoded_case_value)239 SwitchTestCase MakeSwitchTestCase(uint32_t integer_width,
240                                   uint32_t integer_signedness,
241                                   std::string constant_str,
242                                   std::vector<uint32_t> encoded_constant,
243                                   std::string case_value_str,
244                                   std::vector<uint32_t> encoded_case_value) {
245   std::stringstream ss;
246   ss << integer_width << " " << integer_signedness;
247   return SwitchTestCase{
248       ss.str(),
249       constant_str,
250       case_value_str,
251       {Concatenate(
252           {MakeInstruction(spv::Op::OpTypeInt,
253                            {1, integer_width, integer_signedness}),
254            MakeInstruction(spv::Op::OpConstant,
255                            Concatenate({{1, 2}, encoded_constant})),
256            MakeInstruction(spv::Op::OpSwitch,
257                            Concatenate({{2, 3}, encoded_case_value, {4}}))})}};
258 }
259 
// Integer widths of at most 32 bits: the constant and the case literal each
// encode as exactly one word.
INSTANTIATE_TEST_SUITE_P(
    TextToBinaryOpSwitchValid1Word, OpSwitchValidTest,
    ValuesIn(std::vector<SwitchTestCase>({
        MakeSwitchTestCase(32, 0, "42", {42}, "100", {100}),
        MakeSwitchTestCase(32, 1, "-1", {0xffffffff}, "100", {100}),
        // SPIR-V 1.0 Rev 1 clarified that for an integer narrower than 32-bits,
        // its bits will appear in the lower order bits of the 32-bit word, and
        // a signed integer is sign-extended.
        MakeSwitchTestCase(7, 0, "127", {127}, "100", {100}),
        MakeSwitchTestCase(14, 0, "99", {99}, "100", {100}),
        MakeSwitchTestCase(16, 0, "65535", {65535}, "100", {100}),
        MakeSwitchTestCase(16, 1, "101", {101}, "100", {100}),
        // Demonstrate sign extension
        MakeSwitchTestCase(16, 1, "-2", {0xfffffffe}, "100", {100}),
        // Hex cases
        MakeSwitchTestCase(16, 1, "0x7ffe", {0x7ffe}, "0x1234", {0x1234}),
        // 0x8000 sets bit 15: sign-extended for the signed 16-bit type,
        // zero-extended for the unsigned one on the following line.
        MakeSwitchTestCase(16, 1, "0x8000", {0xffff8000}, "0x8100",
                           {0xffff8100}),
        MakeSwitchTestCase(16, 0, "0x8000", {0x00008000}, "0x8100", {0x8100}),
    })));
280 
// Integer widths above 32 bits (up to 64): the constant and the case literal
// each encode as two words.
// NB: The word holding the LOW ORDER bits shows up first.
INSTANTIATE_TEST_SUITE_P(
    TextToBinaryOpSwitchValid2Words, OpSwitchValidTest,
    ValuesIn(std::vector<SwitchTestCase>({
        MakeSwitchTestCase(33, 0, "101", {101, 0}, "500", {500, 0}),
        MakeSwitchTestCase(48, 1, "-1", {0xffffffff, 0xffffffff}, "900",
                           {900, 0}),
        MakeSwitchTestCase(64, 1, "-2", {0xfffffffe, 0xffffffff}, "-5",
                           {0xfffffffb, uint32_t(-1)}),
        // Hex cases
        MakeSwitchTestCase(48, 1, "0x7fffffffffff", {0xffffffff, 0x00007fff},
                           "100", {100, 0}),
        // 0x800000000000 sets bit 47: sign-extended for the signed 48-bit
        // type, zero-extended for the unsigned one that follows.
        MakeSwitchTestCase(48, 1, "0x800000000000", {0x00000000, 0xffff8000},
                           "0x800000000000", {0x00000000, 0xffff8000}),
        MakeSwitchTestCase(48, 0, "0x800000000000", {0x00000000, 0x00008000},
                           "0x800000000000", {0x00000000, 0x00008000}),
        MakeSwitchTestCase(63, 0, "0x500000000", {0, 5}, "12", {12, 0}),
        MakeSwitchTestCase(64, 0, "0x600000000", {0, 6}, "12", {12, 0}),
        MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
    })));
301 
302 using ControlFlowRoundTripTest = RoundTripTest;
303 
TEST_P(ControlFlowRoundTripTest,DisassemblyEqualsAssemblyInput)304 TEST_P(ControlFlowRoundTripTest, DisassemblyEqualsAssemblyInput) {
305   const std::string assembly = GetParam();
306   EXPECT_THAT(EncodeAndDecodeSuccessfully(assembly), Eq(assembly)) << assembly;
307 }
308 
// OpSwitch on unsigned selectors of several widths must survive an
// assemble/disassemble round trip unchanged, including multi-word literals.
INSTANTIATE_TEST_SUITE_P(
    OpSwitchRoundTripUnsignedIntegers, ControlFlowRoundTripTest,
    ValuesIn(std::vector<std::string>({
        // Unsigned 16-bit.
        "%1 = OpTypeInt 16 0\n%2 = OpConstant %1 65535\nOpSwitch %2 %3\n",
        // Unsigned 32-bit, three non-default cases.
        "%1 = OpTypeInt 32 0\n%2 = OpConstant %1 123456\n"
        "OpSwitch %2 %3 100 %4 102 %5 1000000 %6\n",
        // Unsigned 48-bit, three non-default cases.
        "%1 = OpTypeInt 48 0\n%2 = OpConstant %1 5000000000\n"
        "OpSwitch %2 %3 100 %4 102 %5 6000000000 %6\n",
        // Unsigned 64-bit, three non-default cases.
        "%1 = OpTypeInt 64 0\n%2 = OpConstant %1 9223372036854775807\n"
        "OpSwitch %2 %3 100 %4 102 %5 9000000000000000000 %6\n",
    })));
324 
// OpSwitch on signed selectors must survive an assemble/disassemble round
// trip unchanged. The selector constants and case literals include negative
// values and the extremes of each width.
INSTANTIATE_TEST_SUITE_P(
    OpSwitchRoundTripSignedIntegers, ControlFlowRoundTripTest,
    ValuesIn(std::vector<std::string>{
        // Signed 16-bit, with two non-default cases
        "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 32767\n"
        "OpSwitch %2 %3 99 %4 -102 %5\n",
        "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 -32768\n"
        "OpSwitch %2 %3 99 %4 -102 %5\n",
        // Signed 32-bit, two non-default cases.
        "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 -123456\n"
        "OpSwitch %2 %3 100 %4 -123456 %5\n",
        "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 123456\n"
        "OpSwitch %2 %3 100 %4 123456 %5\n",
        // Signed 48-bit, three non-default cases.
        "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 5000000000\n"
        "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n",
        "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 -5000000000\n"
        "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n",
        // Signed 64-bit, three non-default cases.
        "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 9223372036854775807\n"
        "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n",
        "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 -9223372036854775808\n"
        "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n",
    }));
349 
350 using OpSwitchInvalidTypeTestCase =
351     spvtest::TextToBinaryTestBase<TestWithParam<std::string>>;
352 
TEST_P(OpSwitchInvalidTypeTestCase,InvalidTypes)353 TEST_P(OpSwitchInvalidTypeTestCase, InvalidTypes) {
354   const std::string input =
355       "%1 = " + GetParam() +
356       "\n"
357       "%3 = OpCopyObject %1 %2\n"  // We only care the type of the expression
358       "     OpSwitch %3 %default 32 %c\n";
359   EXPECT_THAT(CompileFailure(input),
360               Eq("The selector operand for OpSwitch must be the result of an "
361                  "instruction that generates an integer scalar"));
362 }
363 
364 // clang-format off
365 INSTANTIATE_TEST_SUITE_P(
366     TextToBinaryOpSwitchInvalidTests, OpSwitchInvalidTypeTestCase,
367     ValuesIn(std::vector<std::string>{
368       {"OpTypeVoid",
369        "OpTypeBool",
370        "OpTypeFloat 32",
371        "OpTypeVector %a 32",
372        "OpTypeMatrix %a 32",
373        "OpTypeImage %a 1D 0 0 0 0 Unknown",
374        "OpTypeSampler",
375        "OpTypeSampledImage %a",
376        "OpTypeArray %a %b",
377        "OpTypeRuntimeArray %a",
378        "OpTypeStruct %a",
379        "OpTypeOpaque \"Foo\"",
380        "OpTypePointer UniformConstant %a",
381        "OpTypeFunction %a %b",
382        "OpTypeEvent",
383        "OpTypeDeviceEvent",
384        "OpTypeReserveId",
385        "OpTypeQueue",
386        "OpTypePipe ReadOnly",
387 
388        // Skip OpTypeForwardPointer because it doesn't even produce a result
389        // ID.
390 
391        // At least one thing that isn't a type at all
392        "OpNot %a %b"
393       },
394     }));
395 // clang-format on
396 
397 using OpKillTest = spvtest::TextToBinaryTest;
398 
399 INSTANTIATE_TEST_SUITE_P(OpKillTest, ControlFlowRoundTripTest,
400                          Values("OpKill\n"));
401 
TEST_F(OpKillTest,ExtraArgsAssemblyError)402 TEST_F(OpKillTest, ExtraArgsAssemblyError) {
403   const std::string input = "OpKill 1";
404   EXPECT_THAT(CompileFailure(input),
405               Eq("Expected <opcode> or <result-id> at the beginning of an "
406                  "instruction, found '1'."));
407 }
408 
409 using OpTerminateInvocationTest = spvtest::TextToBinaryTest;
410 
411 INSTANTIATE_TEST_SUITE_P(OpTerminateInvocationTest, ControlFlowRoundTripTest,
412                          Values("OpTerminateInvocation\n"));
413 
TEST_F(OpTerminateInvocationTest,ExtraArgsAssemblyError)414 TEST_F(OpTerminateInvocationTest, ExtraArgsAssemblyError) {
415   const std::string input = "OpTerminateInvocation 1";
416   EXPECT_THAT(CompileFailure(input),
417               Eq("Expected <opcode> or <result-id> at the beginning of an "
418                  "instruction, found '1'."));
419 }
420 
// TODO(dneto): OpPhi
// TODO(dneto): OpLoopMerge (only the loop control operand is covered above)
// TODO(dneto): OpLabel
// TODO(dneto): OpBranch
// TODO(dneto): OpSwitch (basic encoding, error and round-trip cases exist above)
// TODO(dneto): OpReturn
// TODO(dneto): OpReturnValue
// TODO(dneto): OpUnreachable
// TODO(dneto): OpLifetimeStart
// TODO(dneto): OpLifetimeStop
431 
432 }  // namespace
433 }  // namespace spvtools
434