1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <utility>
19 #include <vector>
20 
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/build_module.h"
24 #include "source/opt/def_use_manager.h"
25 #include "source/opt/ir_context.h"
26 #include "source/opt/module.h"
27 #include "spirv-tools/libspirv.hpp"
28 #include "test/opt/pass_fixture.h"
29 #include "test/opt/pass_utils.h"
30 
31 namespace spvtools {
32 namespace opt {
33 namespace analysis {
34 namespace {
35 
36 using ::testing::Contains;
37 using ::testing::UnorderedElementsAre;
38 using ::testing::UnorderedElementsAreArray;
39 
40 // Returns the number of uses of |id|.
NumUses(const std::unique_ptr<IRContext> & context,uint32_t id)41 uint32_t NumUses(const std::unique_ptr<IRContext>& context, uint32_t id) {
42   uint32_t count = 0;
43   context->get_def_use_mgr()->ForEachUse(
44       id, [&count](Instruction*, uint32_t) { ++count; });
45   return count;
46 }
47 
48 // Returns the opcode of each use of |id|.
49 //
50 // If |id| is used multiple times in a single instruction, that instruction's
51 // opcode will appear a corresponding number of times.
GetUseOpcodes(const std::unique_ptr<IRContext> & context,uint32_t id)52 std::vector<spv::Op> GetUseOpcodes(const std::unique_ptr<IRContext>& context,
53                                    uint32_t id) {
54   std::vector<spv::Op> opcodes;
55   context->get_def_use_mgr()->ForEachUse(
56       id, [&opcodes](Instruction* user, uint32_t) {
57         opcodes.push_back(user->opcode());
58       });
59   return opcodes;
60 }
61 
62 // Disassembles the given |inst| and returns the disassembly.
DisassembleInst(Instruction * inst)63 std::string DisassembleInst(Instruction* inst) {
64   SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
65 
66   std::vector<uint32_t> binary;
67   // We need this to generate the necessary header in the binary.
68   tools.Assemble("", &binary);
69   inst->ToBinaryWithoutAttachedDebugInsts(&binary);
70 
71   std::string text;
72   // We'll need to check the underlying id numbers.
73   // So turn off friendly names for ids.
74   tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
75   while (!text.empty() && text.back() == '\n') text.pop_back();
76   return text;
77 }
78 
79 // A struct for holding expected id defs and uses.
80 struct InstDefUse {
81   using IdInstPair = std::pair<uint32_t, std::string>;
82   using IdInstsPair = std::pair<uint32_t, std::vector<std::string>>;
83 
84   // Ids and their corresponding def instructions.
85   std::vector<IdInstPair> defs;
86   // Ids and their corresponding use instructions.
87   std::vector<IdInstsPair> uses;
88 };
89 
90 // Checks that the |actual_defs| and |actual_uses| are in accord with
91 // |expected_defs_uses|.
CheckDef(const InstDefUse & expected_defs_uses,const DefUseManager::IdToDefMap & actual_defs)92 void CheckDef(const InstDefUse& expected_defs_uses,
93               const DefUseManager::IdToDefMap& actual_defs) {
94   // Check defs.
95   ASSERT_EQ(expected_defs_uses.defs.size(), actual_defs.size());
96   for (uint32_t i = 0; i < expected_defs_uses.defs.size(); ++i) {
97     const auto id = expected_defs_uses.defs[i].first;
98     const auto expected_def = expected_defs_uses.defs[i].second;
99     ASSERT_EQ(1u, actual_defs.count(id)) << "expected to def id [" << id << "]";
100     auto def = actual_defs.at(id);
101     if (def->opcode() != spv::Op::OpConstant) {
102       // Constants don't disassemble properly without a full context.
103       EXPECT_EQ(expected_def, DisassembleInst(actual_defs.at(id)));
104     }
105   }
106 }
107 
108 using UserMap = std::unordered_map<uint32_t, std::vector<Instruction*>>;
109 
110 // Creates a mapping of all definitions to their users (except OpConstant).
111 //
112 // OpConstants are skipped because they cannot be disassembled in isolation.
BuildAllUsers(const DefUseManager * mgr,uint32_t idBound)113 UserMap BuildAllUsers(const DefUseManager* mgr, uint32_t idBound) {
114   UserMap userMap;
115   for (uint32_t id = 0; id != idBound; ++id) {
116     if (mgr->GetDef(id)) {
117       mgr->ForEachUser(id, [id, &userMap](Instruction* user) {
118         if (user->opcode() != spv::Op::OpConstant) {
119           userMap[id].push_back(user);
120         }
121       });
122     }
123   }
124   return userMap;
125 }
126 
127 // Constants don't disassemble properly without a full context, so skip them as
128 // checks.
CheckUse(const InstDefUse & expected_defs_uses,const DefUseManager * mgr,uint32_t idBound)129 void CheckUse(const InstDefUse& expected_defs_uses, const DefUseManager* mgr,
130               uint32_t idBound) {
131   UserMap actual_uses = BuildAllUsers(mgr, idBound);
132   // Check uses.
133   ASSERT_EQ(expected_defs_uses.uses.size(), actual_uses.size());
134   for (uint32_t i = 0; i < expected_defs_uses.uses.size(); ++i) {
135     const auto id = expected_defs_uses.uses[i].first;
136     const auto& expected_uses = expected_defs_uses.uses[i].second;
137 
138     ASSERT_EQ(1u, actual_uses.count(id)) << "expected to use id [" << id << "]";
139     const auto& uses = actual_uses.at(id);
140 
141     ASSERT_EQ(expected_uses.size(), uses.size())
142         << "id [" << id << "] # uses: expected: " << expected_uses.size()
143         << " actual: " << uses.size();
144 
145     std::vector<std::string> actual_uses_disassembled;
146     for (const auto actual_use : uses) {
147       actual_uses_disassembled.emplace_back(DisassembleInst(actual_use));
148     }
149     EXPECT_THAT(actual_uses_disassembled,
150                 UnorderedElementsAreArray(expected_uses));
151   }
152 }
153 
154 // The following test case mimics how LLVM handles induction variables.
155 // But, yeah, it's not very readable. However, we only care about the id
156 // defs and uses. So, no need to make sure this is valid OpPhi construct.
157 const char kOpPhiTestFunction[] =
158     " %1 = OpTypeVoid "
159     " %6 = OpTypeInt 32 0 "
160     "%10 = OpTypeFloat 32 "
161     "%16 = OpTypeBool "
162     " %3 = OpTypeFunction %1 "
163     " %8 = OpConstant %6 0 "
164     "%18 = OpConstant %6 1 "
165     "%12 = OpConstant %10 1.0 "
166     " %2 = OpFunction %1 None %3 "
167     " %4 = OpLabel "
168     "      OpBranch %5 "
169 
170     " %5 = OpLabel "
171     " %7 = OpPhi %6 %8 %4 %9 %5 "
172     "%11 = OpPhi %10 %12 %4 %13 %5 "
173     " %9 = OpIAdd %6 %7 %8 "
174     "%13 = OpFAdd %10 %11 %12 "
175     "%17 = OpSLessThan %16 %7 %18 "
176     "      OpLoopMerge %19 %5 None "
177     "      OpBranchConditional %17 %5 %19 "
178 
179     "%19 = OpLabel "
180     "      OpReturn "
181     "      OpFunctionEnd";
182 
183 struct ParseDefUseCase {
184   const char* text;
185   InstDefUse du;
186 };
187 
188 using ParseDefUseTest = ::testing::TestWithParam<ParseDefUseCase>;
189 
TEST_P(ParseDefUseTest,Case)190 TEST_P(ParseDefUseTest, Case) {
191   const auto& tc = GetParam();
192 
193   // Build module.
194   const std::vector<const char*> text = {tc.text};
195   std::unique_ptr<IRContext> context =
196       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
197                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
198   ASSERT_NE(nullptr, context);
199 
200   // Analyze def and use.
201   DefUseManager manager(context->module());
202 
203   CheckDef(tc.du, manager.id_to_defs());
204   CheckUse(tc.du, &manager, context->module()->IdBound());
205 }
206 
207 // clang-format off
208 INSTANTIATE_TEST_SUITE_P(
209     TestCase, ParseDefUseTest,
210     ::testing::ValuesIn(std::vector<ParseDefUseCase>{
211         {"", {{}, {}}},                              // no instruction
212         {"OpMemoryModel Logical GLSL450", {{}, {}}}, // no def and use
213         { // single def, no use
214           "%1 = OpString \"wow\"",
215           {
216             {{1, "%1 = OpString \"wow\""}}, // defs
217             {}                              // uses
218           }
219         },
220         { // multiple def, no use
221           "%1 = OpString \"hello\" "
222           "%2 = OpString \"world\" "
223           "%3 = OpTypeVoid",
224           {
225             {  // defs
226               {1, "%1 = OpString \"hello\""},
227               {2, "%2 = OpString \"world\""},
228               {3, "%3 = OpTypeVoid"},
229             },
230             {} // uses
231           }
232         },
233         { // multiple def, multiple use
234           "%1 = OpTypeBool "
235           "%2 = OpTypeVector %1 3 "
236           "%3 = OpTypeMatrix %2 3",
237           {
238             { // defs
239               {1, "%1 = OpTypeBool"},
240               {2, "%2 = OpTypeVector %1 3"},
241               {3, "%3 = OpTypeMatrix %2 3"},
242             },
243             { // uses
244               {1, {"%2 = OpTypeVector %1 3"}},
245               {2, {"%3 = OpTypeMatrix %2 3"}},
246             }
247           }
248         },
249         { // multiple use of the same id
250           "%1 = OpTypeBool "
251           "%2 = OpTypeVector %1 2 "
252           "%3 = OpTypeVector %1 3 "
253           "%4 = OpTypeVector %1 4",
254           {
255             { // defs
256               {1, "%1 = OpTypeBool"},
257               {2, "%2 = OpTypeVector %1 2"},
258               {3, "%3 = OpTypeVector %1 3"},
259               {4, "%4 = OpTypeVector %1 4"},
260             },
261             { // uses
262               {1,
263                 {
264                   "%2 = OpTypeVector %1 2",
265                   "%3 = OpTypeVector %1 3",
266                   "%4 = OpTypeVector %1 4",
267                 }
268               },
269             }
270           }
271         },
272         { // labels
273           "%1 = OpTypeVoid "
274           "%2 = OpTypeBool "
275           "%3 = OpTypeFunction %1 "
276           "%4 = OpConstantTrue %2 "
277           "%5 = OpFunction %1 None %3 "
278 
279           "%6 = OpLabel "
280           "OpBranchConditional %4 %7 %8 "
281 
282           "%7 = OpLabel "
283           "OpBranch %7 "
284 
285           "%8 = OpLabel "
286           "OpReturn "
287 
288           "OpFunctionEnd",
289           {
290             { // defs
291               {1, "%1 = OpTypeVoid"},
292               {2, "%2 = OpTypeBool"},
293               {3, "%3 = OpTypeFunction %1"},
294               {4, "%4 = OpConstantTrue %2"},
295               {5, "%5 = OpFunction %1 None %3"},
296               {6, "%6 = OpLabel"},
297               {7, "%7 = OpLabel"},
298               {8, "%8 = OpLabel"},
299             },
300             { // uses
301               {1, {
302                     "%3 = OpTypeFunction %1",
303                     "%5 = OpFunction %1 None %3",
304                   }
305               },
306               {2, {"%4 = OpConstantTrue %2"}},
307               {3, {"%5 = OpFunction %1 None %3"}},
308               {4, {"OpBranchConditional %4 %7 %8"}},
309               {7,
310                 {
311                   "OpBranchConditional %4 %7 %8",
312                   "OpBranch %7",
313                 }
314               },
315               {8, {"OpBranchConditional %4 %7 %8"}},
316             }
317           }
318         },
319         { // cross function
320           "%1 = OpTypeBool "
321           "%3 = OpTypeFunction %1 "
322           "%2 = OpFunction %1 None %3 "
323 
324           "%4 = OpLabel "
325           "%5 = OpVariable %1 Function "
326           "%6 = OpFunctionCall %1 %2 %5 "
327           "OpReturnValue %6 "
328 
329           "OpFunctionEnd",
330           {
331             { // defs
332               {1, "%1 = OpTypeBool"},
333               {2, "%2 = OpFunction %1 None %3"},
334               {3, "%3 = OpTypeFunction %1"},
335               {4, "%4 = OpLabel"},
336               {5, "%5 = OpVariable %1 Function"},
337               {6, "%6 = OpFunctionCall %1 %2 %5"},
338             },
339             { // uses
340               {1,
341                 {
342                   "%2 = OpFunction %1 None %3",
343                   "%3 = OpTypeFunction %1",
344                   "%5 = OpVariable %1 Function",
345                   "%6 = OpFunctionCall %1 %2 %5",
346                 }
347               },
348               {2, {"%6 = OpFunctionCall %1 %2 %5"}},
349               {3, {"%2 = OpFunction %1 None %3"}},
350               {5, {"%6 = OpFunctionCall %1 %2 %5"}},
351               {6, {"OpReturnValue %6"}},
352             }
353           }
354         },
355         { // selection merge and loop merge
356           "%1 = OpTypeVoid "
357           "%3 = OpTypeFunction %1 "
358           "%10 = OpTypeBool "
359           "%8 = OpConstantTrue %10 "
360           "%2 = OpFunction %1 None %3 "
361 
362           "%4 = OpLabel "
363           "OpLoopMerge %5 %4 None "
364           "OpBranch %6 "
365 
366           "%5 = OpLabel "
367           "OpReturn "
368 
369           "%6 = OpLabel "
370           "OpSelectionMerge %7 None "
371           "OpBranchConditional %8 %9 %7 "
372 
373           "%7 = OpLabel "
374           "OpReturn "
375 
376           "%9 = OpLabel "
377           "OpReturn "
378 
379           "OpFunctionEnd",
380           {
381             { // defs
382               {1, "%1 = OpTypeVoid"},
383               {2, "%2 = OpFunction %1 None %3"},
384               {3, "%3 = OpTypeFunction %1"},
385               {4, "%4 = OpLabel"},
386               {5, "%5 = OpLabel"},
387               {6, "%6 = OpLabel"},
388               {7, "%7 = OpLabel"},
389               {8, "%8 = OpConstantTrue %10"},
390               {9, "%9 = OpLabel"},
391               {10, "%10 = OpTypeBool"},
392             },
393             { // uses
394               {1,
395                 {
396                   "%2 = OpFunction %1 None %3",
397                   "%3 = OpTypeFunction %1",
398                 }
399               },
400               {3, {"%2 = OpFunction %1 None %3"}},
401               {4, {"OpLoopMerge %5 %4 None"}},
402               {5, {"OpLoopMerge %5 %4 None"}},
403               {6, {"OpBranch %6"}},
404               {7,
405                 {
406                   "OpSelectionMerge %7 None",
407                   "OpBranchConditional %8 %9 %7",
408                 }
409               },
410               {8, {"OpBranchConditional %8 %9 %7"}},
411               {9, {"OpBranchConditional %8 %9 %7"}},
412               {10, {"%8 = OpConstantTrue %10"}},
413             }
414           }
415         },
416         { // Forward reference
417           "OpDecorate %1 Block "
418           "OpTypeForwardPointer %2 Input "
419           "%3 = OpTypeInt 32 0 "
420           "%1 = OpTypeStruct %3 "
421           "%2 = OpTypePointer Input %3",
422           {
423             { // defs
424               {1, "%1 = OpTypeStruct %3"},
425               {2, "%2 = OpTypePointer Input %3"},
426               {3, "%3 = OpTypeInt 32 0"},
427             },
428             { // uses
429               {1, {"OpDecorate %1 Block"}},
430               {2, {"OpTypeForwardPointer %2 Input"}},
431               {3,
432                 {
433                   "%1 = OpTypeStruct %3",
434                   "%2 = OpTypePointer Input %3",
435                 }
436               }
437             },
438           },
439         },
440         { // OpPhi
441           kOpPhiTestFunction,
442           {
443             { // defs
444               {1, "%1 = OpTypeVoid"},
445               {2, "%2 = OpFunction %1 None %3"},
446               {3, "%3 = OpTypeFunction %1"},
447               {4, "%4 = OpLabel"},
448               {5, "%5 = OpLabel"},
449               {6, "%6 = OpTypeInt 32 0"},
450               {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
451               {8, "%8 = OpConstant %6 0"},
452               {9, "%9 = OpIAdd %6 %7 %8"},
453               {10, "%10 = OpTypeFloat 32"},
454               {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
455               {12, "%12 = OpConstant %10 1.0"},
456               {13, "%13 = OpFAdd %10 %11 %12"},
457               {16, "%16 = OpTypeBool"},
458               {17, "%17 = OpSLessThan %16 %7 %18"},
459               {18, "%18 = OpConstant %6 1"},
460               {19, "%19 = OpLabel"},
461             },
462             { // uses
463               {1,
464                 {
465                   "%2 = OpFunction %1 None %3",
466                   "%3 = OpTypeFunction %1",
467                 }
468               },
469               {3, {"%2 = OpFunction %1 None %3"}},
470               {4,
471                 {
472                   "%7 = OpPhi %6 %8 %4 %9 %5",
473                   "%11 = OpPhi %10 %12 %4 %13 %5",
474                 }
475               },
476               {5,
477                 {
478                   "OpBranch %5",
479                   "%7 = OpPhi %6 %8 %4 %9 %5",
480                   "%11 = OpPhi %10 %12 %4 %13 %5",
481                   "OpLoopMerge %19 %5 None",
482                   "OpBranchConditional %17 %5 %19",
483                 }
484               },
485               {6,
486                 {
487                   // Can't check constants properly
488                   // "%8 = OpConstant %6 0",
489                   // "%18 = OpConstant %6 1",
490                   "%7 = OpPhi %6 %8 %4 %9 %5",
491                   "%9 = OpIAdd %6 %7 %8",
492                 }
493               },
494               {7,
495                 {
496                   "%9 = OpIAdd %6 %7 %8",
497                   "%17 = OpSLessThan %16 %7 %18",
498                 }
499               },
500               {8,
501                 {
502                   "%7 = OpPhi %6 %8 %4 %9 %5",
503                   "%9 = OpIAdd %6 %7 %8",
504                 }
505               },
506               {9, {"%7 = OpPhi %6 %8 %4 %9 %5"}},
507               {10,
508                 {
509                   // "%12 = OpConstant %10 1.0",
510                   "%11 = OpPhi %10 %12 %4 %13 %5",
511                   "%13 = OpFAdd %10 %11 %12",
512                 }
513               },
514               {11, {"%13 = OpFAdd %10 %11 %12"}},
515               {12,
516                 {
517                   "%11 = OpPhi %10 %12 %4 %13 %5",
518                   "%13 = OpFAdd %10 %11 %12",
519                 }
520               },
521               {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
522               {16, {"%17 = OpSLessThan %16 %7 %18"}},
523               {17, {"OpBranchConditional %17 %5 %19"}},
524               {18, {"%17 = OpSLessThan %16 %7 %18"}},
525               {19,
526                 {
527                   "OpLoopMerge %19 %5 None",
528                   "OpBranchConditional %17 %5 %19",
529                 }
530               },
531             },
532           },
533         },
534         { // OpPhi defining and referencing the same id.
535           "%1 = OpTypeBool "
536           "%3 = OpTypeFunction %1 "
537           "%2 = OpConstantTrue %1 "
538           "%4 = OpFunction %1 None %3 "
539           "%6 = OpLabel "
540           "     OpBranch %7 "
541           "%7 = OpLabel "
542           "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
543           "     OpBranch %7 "
544           "     OpFunctionEnd",
545           {
546             { // defs
547               {1, "%1 = OpTypeBool"},
548               {2, "%2 = OpConstantTrue %1"},
549               {3, "%3 = OpTypeFunction %1"},
550               {4, "%4 = OpFunction %1 None %3"},
551               {6, "%6 = OpLabel"},
552               {7, "%7 = OpLabel"},
553               {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
554             },
555             { // uses
556               {1,
557                 {
558                   "%2 = OpConstantTrue %1",
559                   "%3 = OpTypeFunction %1",
560                   "%4 = OpFunction %1 None %3",
561                   "%8 = OpPhi %1 %8 %7 %2 %6",
562                 }
563               },
564               {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
565               {3, {"%4 = OpFunction %1 None %3"}},
566               {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
567               {7,
568                 {
569                   "OpBranch %7",
570                   "%8 = OpPhi %1 %8 %7 %2 %6",
571                   "OpBranch %7",
572                 }
573               },
574               {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
575             },
576           },
577         },
578     })
579 );
580 // clang-format on
581 
582 struct ReplaceUseCase {
583   const char* before;
584   std::vector<std::pair<uint32_t, uint32_t>> candidates;
585   const char* after;
586   InstDefUse du;
587 };
588 
589 using ReplaceUseTest = ::testing::TestWithParam<ReplaceUseCase>;
590 
591 // Disassembles the given |module| and returns the disassembly.
DisassembleModule(Module * module)592 std::string DisassembleModule(Module* module) {
593   SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
594 
595   std::vector<uint32_t> binary;
596   module->ToBinary(&binary, /* skip_nop = */ false);
597 
598   std::string text;
599   // We'll need to check the underlying id numbers.
600   // So turn off friendly names for ids.
601   tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
602   while (!text.empty() && text.back() == '\n') text.pop_back();
603   return text;
604 }
605 
TEST_P(ReplaceUseTest,Case)606 TEST_P(ReplaceUseTest, Case) {
607   const auto& tc = GetParam();
608 
609   // Build module.
610   const std::vector<const char*> text = {tc.before};
611   std::unique_ptr<IRContext> context =
612       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
613                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
614   ASSERT_NE(nullptr, context);
615 
616   // Force a re-build of def-use manager.
617   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
618   (void)context->get_def_use_mgr();
619 
620   // Do the substitution.
621   for (const auto& candidate : tc.candidates) {
622     context->ReplaceAllUsesWith(candidate.first, candidate.second);
623   }
624 
625   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
626   CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
627   CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
628 }
629 
630 // clang-format off
631 INSTANTIATE_TEST_SUITE_P(
632     TestCase, ReplaceUseTest,
633     ::testing::ValuesIn(std::vector<ReplaceUseCase>{
634       { // no use, no replace request
635         "", {}, "", {},
636       },
637       { // replace one use
638         "%1 = OpTypeBool "
639         "%2 = OpTypeVector %1 3 "
640         "%3 = OpTypeInt 32 0 ",
641         {{1, 3}},
642         "%1 = OpTypeBool\n"
643         "%2 = OpTypeVector %3 3\n"
644         "%3 = OpTypeInt 32 0",
645         {
646           { // defs
647             {1, "%1 = OpTypeBool"},
648             {2, "%2 = OpTypeVector %3 3"},
649             {3, "%3 = OpTypeInt 32 0"},
650           },
651           { // uses
652             {3, {"%2 = OpTypeVector %3 3"}},
653           },
654         },
655       },
656       { // replace and then replace back
657         "%1 = OpTypeBool "
658         "%2 = OpTypeVector %1 3 "
659         "%3 = OpTypeInt 32 0",
660         {{1, 3}, {3, 1}},
661         "%1 = OpTypeBool\n"
662         "%2 = OpTypeVector %1 3\n"
663         "%3 = OpTypeInt 32 0",
664         {
665           { // defs
666             {1, "%1 = OpTypeBool"},
667             {2, "%2 = OpTypeVector %1 3"},
668             {3, "%3 = OpTypeInt 32 0"},
669           },
670           { // uses
671             {1, {"%2 = OpTypeVector %1 3"}},
672           },
673         },
674       },
675       { // replace with the same id
676         "%1 = OpTypeBool "
677         "%2 = OpTypeVector %1 3",
678         {{1, 1}, {2, 2}, {3, 3}},
679         "%1 = OpTypeBool\n"
680         "%2 = OpTypeVector %1 3",
681         {
682           { // defs
683             {1, "%1 = OpTypeBool"},
684             {2, "%2 = OpTypeVector %1 3"},
685           },
686           { // uses
687             {1, {"%2 = OpTypeVector %1 3"}},
688           },
689         },
690       },
691       { // replace in sequence
692         "%1 = OpTypeBool "
693         "%2 = OpTypeVector %1 3 "
694         "%3 = OpTypeInt 32 0 "
695         "%4 = OpTypeInt 32 1 ",
696         {{1, 3}, {3, 4}},
697         "%1 = OpTypeBool\n"
698         "%2 = OpTypeVector %4 3\n"
699         "%3 = OpTypeInt 32 0\n"
700         "%4 = OpTypeInt 32 1",
701         {
702           { // defs
703             {1, "%1 = OpTypeBool"},
704             {2, "%2 = OpTypeVector %4 3"},
705             {3, "%3 = OpTypeInt 32 0"},
706             {4, "%4 = OpTypeInt 32 1"},
707           },
708           { // uses
709             {4, {"%2 = OpTypeVector %4 3"}},
710           },
711         },
712       },
713       { // replace multiple uses
714         "%1 = OpTypeBool "
715         "%2 = OpTypeVector %1 2 "
716         "%3 = OpTypeVector %1 3 "
717         "%4 = OpTypeVector %1 4 "
718         "%5 = OpTypeMatrix %2 2 "
719         "%6 = OpTypeMatrix %3 3 "
720         "%7 = OpTypeMatrix %4 4 "
721         "%8 = OpTypeInt 32 0 "
722         "%9 = OpTypeInt 32 1 "
723         "%10 = OpTypeInt 64 0",
724         {{1, 8}, {2, 9}, {4, 10}},
725         "%1 = OpTypeBool\n"
726         "%2 = OpTypeVector %8 2\n"
727         "%3 = OpTypeVector %8 3\n"
728         "%4 = OpTypeVector %8 4\n"
729         "%5 = OpTypeMatrix %9 2\n"
730         "%6 = OpTypeMatrix %3 3\n"
731         "%7 = OpTypeMatrix %10 4\n"
732         "%8 = OpTypeInt 32 0\n"
733         "%9 = OpTypeInt 32 1\n"
734         "%10 = OpTypeInt 64 0",
735         {
736           { // defs
737             {1, "%1 = OpTypeBool"},
738             {2, "%2 = OpTypeVector %8 2"},
739             {3, "%3 = OpTypeVector %8 3"},
740             {4, "%4 = OpTypeVector %8 4"},
741             {5, "%5 = OpTypeMatrix %9 2"},
742             {6, "%6 = OpTypeMatrix %3 3"},
743             {7, "%7 = OpTypeMatrix %10 4"},
744             {8, "%8 = OpTypeInt 32 0"},
745             {9, "%9 = OpTypeInt 32 1"},
746             {10, "%10 = OpTypeInt 64 0"},
747           },
748           { // uses
749             {8,
750               {
751                 "%2 = OpTypeVector %8 2",
752                 "%3 = OpTypeVector %8 3",
753                 "%4 = OpTypeVector %8 4",
754               }
755             },
756             {9, {"%5 = OpTypeMatrix %9 2"}},
757             {3, {"%6 = OpTypeMatrix %3 3"}},
758             {10, {"%7 = OpTypeMatrix %10 4"}},
759           },
760         },
761       },
762       { // OpPhi.
763         kOpPhiTestFunction,
764         // replace one id used by OpPhi, replace one id generated by OpPhi
765         {{9, 13}, {11, 9}},
766          "%1 = OpTypeVoid\n"
767          "%6 = OpTypeInt 32 0\n"
768          "%10 = OpTypeFloat 32\n"
769          "%16 = OpTypeBool\n"
770          "%3 = OpTypeFunction %1\n"
771          "%8 = OpConstant %6 0\n"
772          "%18 = OpConstant %6 1\n"
773          "%12 = OpConstant %10 1\n"
774          "%2 = OpFunction %1 None %3\n"
775          "%4 = OpLabel\n"
776                "OpBranch %5\n"
777 
778          "%5 = OpLabel\n"
779          "%7 = OpPhi %6 %8 %4 %13 %5\n" // %9 -> %13
780         "%11 = OpPhi %10 %12 %4 %13 %5\n"
781          "%9 = OpIAdd %6 %7 %8\n"
782         "%13 = OpFAdd %10 %9 %12\n"       // %11 -> %9
783         "%17 = OpSLessThan %16 %7 %18\n"
784               "OpLoopMerge %19 %5 None\n"
785               "OpBranchConditional %17 %5 %19\n"
786 
787         "%19 = OpLabel\n"
788               "OpReturn\n"
789               "OpFunctionEnd",
790         {
791           { // defs.
792             {1, "%1 = OpTypeVoid"},
793             {2, "%2 = OpFunction %1 None %3"},
794             {3, "%3 = OpTypeFunction %1"},
795             {4, "%4 = OpLabel"},
796             {5, "%5 = OpLabel"},
797             {6, "%6 = OpTypeInt 32 0"},
798             {7, "%7 = OpPhi %6 %8 %4 %13 %5"},
799             {8, "%8 = OpConstant %6 0"},
800             {9, "%9 = OpIAdd %6 %7 %8"},
801             {10, "%10 = OpTypeFloat 32"},
802             {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
803             {12, "%12 = OpConstant %10 1.0"},
804             {13, "%13 = OpFAdd %10 %9 %12"},
805             {16, "%16 = OpTypeBool"},
806             {17, "%17 = OpSLessThan %16 %7 %18"},
807             {18, "%18 = OpConstant %6 1"},
808             {19, "%19 = OpLabel"},
809           },
810           { // uses
811             {1,
812               {
813                 "%2 = OpFunction %1 None %3",
814                 "%3 = OpTypeFunction %1",
815               }
816             },
817             {3, {"%2 = OpFunction %1 None %3"}},
818             {4,
819               {
820                 "%7 = OpPhi %6 %8 %4 %13 %5",
821                 "%11 = OpPhi %10 %12 %4 %13 %5",
822               }
823             },
824             {5,
825               {
826                 "OpBranch %5",
827                 "%7 = OpPhi %6 %8 %4 %13 %5",
828                 "%11 = OpPhi %10 %12 %4 %13 %5",
829                 "OpLoopMerge %19 %5 None",
830                 "OpBranchConditional %17 %5 %19",
831               }
832             },
833             {6,
834               {
835                 // Can't properly check constants
836                 // "%8 = OpConstant %6 0",
837                 // "%18 = OpConstant %6 1",
838                 "%7 = OpPhi %6 %8 %4 %13 %5",
839                 "%9 = OpIAdd %6 %7 %8"
840               }
841             },
842             {7,
843               {
844                 "%9 = OpIAdd %6 %7 %8",
845                 "%17 = OpSLessThan %16 %7 %18",
846               }
847             },
848             {8,
849               {
850                 "%7 = OpPhi %6 %8 %4 %13 %5",
851                 "%9 = OpIAdd %6 %7 %8",
852               }
853             },
854             {9, {"%13 = OpFAdd %10 %9 %12"}}, // uses of %9 changed from %7 to %13
855             {10,
856               {
857                 "%11 = OpPhi %10 %12 %4 %13 %5",
858                 // "%12 = OpConstant %10 1",
859                 "%13 = OpFAdd %10 %9 %12"
860               }
861             },
862             // no more uses of %11
863             {12,
864               {
865                 "%11 = OpPhi %10 %12 %4 %13 %5",
866                 "%13 = OpFAdd %10 %9 %12"
867               }
868             },
869             {13, {
870                    "%7 = OpPhi %6 %8 %4 %13 %5",
871                    "%11 = OpPhi %10 %12 %4 %13 %5",
872                  }
873             },
874             {16, {"%17 = OpSLessThan %16 %7 %18"}},
875             {17, {"OpBranchConditional %17 %5 %19"}},
876             {18, {"%17 = OpSLessThan %16 %7 %18"}},
877             {19,
878               {
879                 "OpLoopMerge %19 %5 None",
880                 "OpBranchConditional %17 %5 %19",
881               }
882             },
883           },
884         },
885       },
886       { // OpPhi defining and referencing the same id.
887         "%1 = OpTypeBool "
888         "%3 = OpTypeFunction %1 "
889         "%2 = OpConstantTrue %1 "
890 
891         "%4 = OpFunction %3 None %1 "
892         "%6 = OpLabel "
893         "     OpBranch %7 "
894         "%7 = OpLabel "
895         "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
896         "     OpBranch %7 "
897         "     OpFunctionEnd",
898         {{8, 2}},
899         "%1 = OpTypeBool\n"
900         "%3 = OpTypeFunction %1\n"
901         "%2 = OpConstantTrue %1\n"
902 
903         "%4 = OpFunction %3 None %1\n"
904         "%6 = OpLabel\n"
905              "OpBranch %7\n"
906         "%7 = OpLabel\n"
907         "%8 = OpPhi %1 %2 %7 %2 %6\n" // use of %8 changed to %2
908              "OpBranch %7\n"
909              "OpFunctionEnd",
910         {
911           { // defs
912             {1, "%1 = OpTypeBool"},
913             {2, "%2 = OpConstantTrue %1"},
914             {3, "%3 = OpTypeFunction %1"},
915             {4, "%4 = OpFunction %3 None %1"},
916             {6, "%6 = OpLabel"},
917             {7, "%7 = OpLabel"},
918             {8, "%8 = OpPhi %1 %2 %7 %2 %6"},
919           },
920           { // uses
921             {1,
922               {
923                 "%2 = OpConstantTrue %1",
924                 "%3 = OpTypeFunction %1",
925                 "%4 = OpFunction %3 None %1",
926                 "%8 = OpPhi %1 %2 %7 %2 %6",
927               }
928             },
929             {2,
930               {
931                 // Only checking users
932                 "%8 = OpPhi %1 %2 %7 %2 %6",
933               }
934             },
935             {3, {"%4 = OpFunction %3 None %1"}},
936             {6, {"%8 = OpPhi %1 %2 %7 %2 %6"}},
937             {7,
938               {
939                 "OpBranch %7",
940                 "%8 = OpPhi %1 %2 %7 %2 %6",
941                 "OpBranch %7",
942               }
943             },
944             // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
945           },
946         },
947       },
948     })
949 );
950 // clang-format on
951 
952 struct KillDefCase {
953   const char* before;
954   std::vector<uint32_t> ids_to_kill;
955   const char* after;
956   InstDefUse du;
957 };
958 
959 using KillDefTest = ::testing::TestWithParam<KillDefCase>;
960 
TEST_P(KillDefTest,Case)961 TEST_P(KillDefTest, Case) {
962   const auto& tc = GetParam();
963 
964   // Build module.
965   const std::vector<const char*> text = {tc.before};
966   std::unique_ptr<IRContext> context =
967       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
968                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
969   ASSERT_NE(nullptr, context);
970 
971   // Analyze def and use.
972   DefUseManager manager(context->module());
973 
974   // Do the substitution.
975   for (const auto id : tc.ids_to_kill) context->KillDef(id);
976 
977   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
978   CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
979   CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
980 }
981 
982 // clang-format off
983 INSTANTIATE_TEST_SUITE_P(
984     TestCase, KillDefTest,
985     ::testing::ValuesIn(std::vector<KillDefCase>{
986       { // no def, no use, no kill
987         "", {}, "", {}
988       },
989       { // kill nothing
990         "%1 = OpTypeBool "
991         "%2 = OpTypeVector %1 2 "
992         "%3 = OpTypeVector %1 3 ",
993         {},
994         "%1 = OpTypeBool\n"
995         "%2 = OpTypeVector %1 2\n"
996         "%3 = OpTypeVector %1 3",
997         {
998           { // defs
999             {1, "%1 = OpTypeBool"},
1000             {2, "%2 = OpTypeVector %1 2"},
1001             {3, "%3 = OpTypeVector %1 3"},
1002           },
1003           { // uses
1004             {1,
1005               {
1006                 "%2 = OpTypeVector %1 2",
1007                 "%3 = OpTypeVector %1 3",
1008               }
1009             },
1010           },
1011         },
1012       },
1013       { // kill id used, kill id not used, kill id not defined
1014         "%1 = OpTypeBool "
1015         "%2 = OpTypeVector %1 2 "
1016         "%3 = OpTypeVector %1 3 "
1017         "%4 = OpTypeVector %1 4 "
1018         "%5 = OpTypeMatrix %3 3 "
1019         "%6 = OpTypeMatrix %2 3",
1020         {1, 3, 5, 10}, // ids to kill
1021         "%2 = OpTypeVector %1 2\n"
1022         "%4 = OpTypeVector %1 4\n"
1023         "%6 = OpTypeMatrix %2 3",
1024         {
1025           { // defs
1026             {2, "%2 = OpTypeVector %1 2"},
1027             {4, "%4 = OpTypeVector %1 4"},
1028             {6, "%6 = OpTypeMatrix %2 3"},
1029           },
1030           { // uses. %1 and %3 are both killed, so no uses
1031             // recorded for them anymore.
1032             {2, {"%6 = OpTypeMatrix %2 3"}},
1033           }
1034         },
1035       },
1036       { // OpPhi.
1037         kOpPhiTestFunction,
1038         {9, 11}, // kill one id used by OpPhi, kill one id generated by OpPhi
1039          "%1 = OpTypeVoid\n"
1040          "%6 = OpTypeInt 32 0\n"
1041          "%10 = OpTypeFloat 32\n"
1042          "%16 = OpTypeBool\n"
1043          "%3 = OpTypeFunction %1\n"
1044          "%8 = OpConstant %6 0\n"
1045          "%18 = OpConstant %6 1\n"
1046          "%12 = OpConstant %10 1\n"
1047          "%2 = OpFunction %1 None %3\n"
1048          "%4 = OpLabel\n"
1049                "OpBranch %5\n"
1050 
1051          "%5 = OpLabel\n"
1052          "%7 = OpPhi %6 %8 %4 %9 %5\n"
1053         "%13 = OpFAdd %10 %11 %12\n"
1054         "%17 = OpSLessThan %16 %7 %18\n"
1055               "OpLoopMerge %19 %5 None\n"
1056               "OpBranchConditional %17 %5 %19\n"
1057 
1058         "%19 = OpLabel\n"
1059               "OpReturn\n"
1060               "OpFunctionEnd",
1061         {
1062           { // defs. %9 & %11 are killed.
1063             {1, "%1 = OpTypeVoid"},
1064             {2, "%2 = OpFunction %1 None %3"},
1065             {3, "%3 = OpTypeFunction %1"},
1066             {4, "%4 = OpLabel"},
1067             {5, "%5 = OpLabel"},
1068             {6, "%6 = OpTypeInt 32 0"},
1069             {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
1070             {8, "%8 = OpConstant %6 0"},
1071             {10, "%10 = OpTypeFloat 32"},
1072             {12, "%12 = OpConstant %10 1.0"},
1073             {13, "%13 = OpFAdd %10 %11 %12"},
1074             {16, "%16 = OpTypeBool"},
1075             {17, "%17 = OpSLessThan %16 %7 %18"},
1076             {18, "%18 = OpConstant %6 1"},
1077             {19, "%19 = OpLabel"},
1078           },
1079           { // uses
1080             {1,
1081               {
1082                 "%2 = OpFunction %1 None %3",
1083                 "%3 = OpTypeFunction %1",
1084               }
1085             },
1086             {3, {"%2 = OpFunction %1 None %3"}},
1087             {4,
1088               {
1089                 "%7 = OpPhi %6 %8 %4 %9 %5",
1090                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1091               }
1092             },
1093             {5,
1094               {
1095                 "OpBranch %5",
1096                 "%7 = OpPhi %6 %8 %4 %9 %5",
1097                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1098                 "OpLoopMerge %19 %5 None",
1099                 "OpBranchConditional %17 %5 %19",
1100               }
1101             },
1102             {6,
1103               {
1104                 // Can't properly check constants
1105                 // "%8 = OpConstant %6 0",
1106                 // "%18 = OpConstant %6 1",
1107                 "%7 = OpPhi %6 %8 %4 %9 %5",
1108                 // "%9 = OpIAdd %6 %7 %8"
1109               }
1110             },
1111             {7, {"%17 = OpSLessThan %16 %7 %18"}},
1112             {8,
1113               {
1114                 "%7 = OpPhi %6 %8 %4 %9 %5",
1115                 // "%9 = OpIAdd %6 %7 %8",
1116               }
1117             },
1118             // {9, {"%7 = OpPhi %6 %8 %4 %13 %5"}},
1119             {10,
1120               {
1121                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1122                 // "%12 = OpConstant %10 1",
1123                 "%13 = OpFAdd %10 %11 %12"
1124               }
1125             },
1126             // {11, {"%13 = OpFAdd %10 %11 %12"}},
1127             {12,
1128               {
1129                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1130                 "%13 = OpFAdd %10 %11 %12"
1131               }
1132             },
1133             // {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
1134             {16, {"%17 = OpSLessThan %16 %7 %18"}},
1135             {17, {"OpBranchConditional %17 %5 %19"}},
1136             {18, {"%17 = OpSLessThan %16 %7 %18"}},
1137             {19,
1138               {
1139                 "OpLoopMerge %19 %5 None",
1140                 "OpBranchConditional %17 %5 %19",
1141               }
1142             },
1143           },
1144         },
1145       },
1146       { // OpPhi defining and referencing the same id.
1147         "%1 = OpTypeBool "
1148         "%3 = OpTypeFunction %1 "
1149         "%2 = OpConstantTrue %1 "
1150         "%4 = OpFunction %3 None %1 "
1151         "%6 = OpLabel "
1152         "     OpBranch %7 "
1153         "%7 = OpLabel "
1154         "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
1155         "     OpBranch %7 "
1156         "     OpFunctionEnd",
1157         {8},
1158         "%1 = OpTypeBool\n"
1159         "%3 = OpTypeFunction %1\n"
1160         "%2 = OpConstantTrue %1\n"
1161 
1162         "%4 = OpFunction %3 None %1\n"
1163         "%6 = OpLabel\n"
1164              "OpBranch %7\n"
1165         "%7 = OpLabel\n"
1166              "OpBranch %7\n"
1167              "OpFunctionEnd",
1168         {
1169           { // defs
1170             {1, "%1 = OpTypeBool"},
1171             {2, "%2 = OpConstantTrue %1"},
1172             {3, "%3 = OpTypeFunction %1"},
1173             {4, "%4 = OpFunction %3 None %1"},
1174             {6, "%6 = OpLabel"},
1175             {7, "%7 = OpLabel"},
1176             // {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
1177           },
1178           { // uses
1179             {1,
1180               {
1181                 "%2 = OpConstantTrue %1",
1182                 "%3 = OpTypeFunction %1",
1183                 "%4 = OpFunction %3 None %1",
1184                 // "%8 = OpPhi %1 %8 %7 %2 %6",
1185               }
1186             },
1187             // {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1188             {3, {"%4 = OpFunction %3 None %1"}},
1189             // {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1190             {7,
1191               {
1192                 "OpBranch %7",
1193                 // "%8 = OpPhi %1 %8 %7 %2 %6",
1194                 "OpBranch %7",
1195               }
1196             },
1197             // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1198           },
1199         },
1200       },
1201     })
1202 );
1203 // clang-format on
1204 
TEST(DefUseTest,OpSwitch)1205 TEST(DefUseTest, OpSwitch) {
1206   // Because disassembler has basic type check for OpSwitch's selector, we
1207   // cannot use the DisassembleInst() in the above. Thus, this special spotcheck
1208   // test case.
1209 
1210   const char original_text[] =
1211       // int64 f(int64 v) {
1212       //   switch (v) {
1213       //     case 1:                   break;
1214       //     case -4294967296:         break;
1215       //     case 9223372036854775807: break;
1216       //     default:                  break;
1217       //   }
1218       //   return v;
1219       // }
1220       " %1 = OpTypeInt 64 1 "
1221       " %3 = OpTypePointer Input %1 "
1222       " %2 = OpFunction %1 None %3 "  // %3 is int64(int64)*
1223       " %4 = OpFunctionParameter %1 "
1224       " %5 = OpLabel "
1225       " %6 = OpLoad %1 %4 "  // selector value
1226       "      OpSelectionMerge %7 None "
1227       "      OpSwitch %6 %8 "
1228       "                  1                    %9 "  // 1
1229       "                  -4294967296         %10 "  // -2^32
1230       "                  9223372036854775807 %11 "  // 2^63-1
1231       " %8 = OpLabel "                              // default
1232       "      OpBranch %7 "
1233       " %9 = OpLabel "
1234       "      OpBranch %7 "
1235       "%10 = OpLabel "
1236       "      OpBranch %7 "
1237       "%11 = OpLabel "
1238       "      OpBranch %7 "
1239       " %7 = OpLabel "
1240       "      OpReturnValue %6 "
1241       "      OpFunctionEnd";
1242 
1243   std::unique_ptr<IRContext> context =
1244       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text,
1245                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1246   ASSERT_NE(nullptr, context);
1247 
1248   // Force a re-build of def-use manager.
1249   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1250   (void)context->get_def_use_mgr();
1251 
1252   // Do a bunch replacements.
1253   context->ReplaceAllUsesWith(11, 7);   // to existing id
1254   context->ReplaceAllUsesWith(10, 11);  // to existing id
1255   context->ReplaceAllUsesWith(9, 10);   // to existing id
1256 
1257   // clang-format off
1258   const char modified_text[] =
1259        "%1 = OpTypeInt 64 1\n"
1260        "%3 = OpTypePointer Input %1\n"
1261        "%2 = OpFunction %1 None %3\n" // %3 is int64(int64)*
1262        "%4 = OpFunctionParameter %1\n"
1263        "%5 = OpLabel\n"
1264        "%6 = OpLoad %1 %4\n" // selector value
1265             "OpSelectionMerge %7 None\n"
1266             "OpSwitch %6 %8 1 %10 -4294967296 %11 9223372036854775807 %7\n" // changed!
1267        "%8 = OpLabel\n"      // default
1268             "OpBranch %7\n"
1269        "%9 = OpLabel\n"
1270             "OpBranch %7\n"
1271       "%10 = OpLabel\n"
1272             "OpBranch %7\n"
1273       "%11 = OpLabel\n"
1274             "OpBranch %7\n"
1275        "%7 = OpLabel\n"
1276             "OpReturnValue %6\n"
1277             "OpFunctionEnd";
1278   // clang-format on
1279 
1280   EXPECT_EQ(modified_text, DisassembleModule(context->module()));
1281 
1282   InstDefUse def_uses = {};
1283   def_uses.defs = {
1284       {1, "%1 = OpTypeInt 64 1"},
1285       {2, "%2 = OpFunction %1 None %3"},
1286       {3, "%3 = OpTypePointer Input %1"},
1287       {4, "%4 = OpFunctionParameter %1"},
1288       {5, "%5 = OpLabel"},
1289       {6, "%6 = OpLoad %1 %4"},
1290       {7, "%7 = OpLabel"},
1291       {8, "%8 = OpLabel"},
1292       {9, "%9 = OpLabel"},
1293       {10, "%10 = OpLabel"},
1294       {11, "%11 = OpLabel"},
1295   };
1296   CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs());
1297 
1298   {
1299     EXPECT_EQ(2u, NumUses(context, 6));
1300     std::vector<spv::Op> opcodes = GetUseOpcodes(context, 6u);
1301     EXPECT_THAT(opcodes, UnorderedElementsAre(spv::Op::OpSwitch,
1302                                               spv::Op::OpReturnValue));
1303   }
1304   {
1305     EXPECT_EQ(6u, NumUses(context, 7));
1306     std::vector<spv::Op> opcodes = GetUseOpcodes(context, 7u);
1307     // OpSwitch is now a user of %7.
1308     EXPECT_THAT(opcodes, UnorderedElementsAre(
1309                              spv::Op::OpSelectionMerge, spv::Op::OpBranch,
1310                              spv::Op::OpBranch, spv::Op::OpBranch,
1311                              spv::Op::OpBranch, spv::Op::OpSwitch));
1312   }
1313   // Check all ids only used by OpSwitch after replacement.
1314   for (const auto id : {8u, 10u, 11u}) {
1315     EXPECT_EQ(1u, NumUses(context, id));
1316     EXPECT_EQ(spv::Op::OpSwitch, GetUseOpcodes(context, id).back());
1317   }
1318 }
1319 
1320 // Test case for analyzing individual instructions.
1321 struct AnalyzeInstDefUseTestCase {
1322   const char* module_text;
1323   InstDefUse expected_define_use;
1324 };
1325 
1326 using AnalyzeInstDefUseTest =
1327     ::testing::TestWithParam<AnalyzeInstDefUseTestCase>;
1328 
1329 // Test the analyzing result for individual instructions.
TEST_P(AnalyzeInstDefUseTest,Case)1330 TEST_P(AnalyzeInstDefUseTest, Case) {
1331   auto tc = GetParam();
1332 
1333   // Build module.
1334   std::unique_ptr<IRContext> context =
1335       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text);
1336   ASSERT_NE(nullptr, context);
1337 
1338   // Analyze the instructions.
1339   DefUseManager manager(context->module());
1340 
1341   CheckDef(tc.expected_define_use, manager.id_to_defs());
1342   CheckUse(tc.expected_define_use, &manager, context->module()->IdBound());
1343   // CheckUse(tc.expected_define_use, manager.id_to_uses());
1344 }
1345 
1346 // clang-format off
1347 INSTANTIATE_TEST_SUITE_P(
1348     TestCase, AnalyzeInstDefUseTest,
1349     ::testing::ValuesIn(std::vector<AnalyzeInstDefUseTestCase>{
1350       { // A type declaring instruction.
1351         "%1 = OpTypeInt 32 1",
1352         {
1353           // defs
1354           {{1, "%1 = OpTypeInt 32 1"}},
1355           {}, // no uses
1356         },
1357       },
1358       { // A type declaring instruction and a constant value.
1359         "%1 = OpTypeBool "
1360         "%2 = OpConstantTrue %1",
1361         {
1362           { // defs
1363             {1, "%1 = OpTypeBool"},
1364             {2, "%2 = OpConstantTrue %1"},
1365           },
1366           { // uses
1367             {1, {"%2 = OpConstantTrue %1"}},
1368           },
1369         },
1370       },
1371       }));
1372 // clang-format on
1373 
1374 using AnalyzeInstDefUse = ::testing::Test;
1375 
TEST(AnalyzeInstDefUse,UseWithNoResultId)1376 TEST(AnalyzeInstDefUse, UseWithNoResultId) {
1377   IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr);
1378 
1379   // Analyze the instructions.
1380   DefUseManager manager(context.module());
1381 
1382   Instruction label(&context, spv::Op::OpLabel, 0, 2, {});
1383   manager.AnalyzeInstDefUse(&label);
1384 
1385   Instruction branch(&context, spv::Op::OpBranch, 0, 0,
1386                      {{SPV_OPERAND_TYPE_ID, {2}}});
1387   manager.AnalyzeInstDefUse(&branch);
1388   context.module()->SetIdBound(3);
1389 
1390   InstDefUse expected = {
1391       // defs
1392       {
1393           {2, "%2 = OpLabel"},
1394       },
1395       // uses
1396       {{2, {"OpBranch %2"}}},
1397   };
1398 
1399   CheckDef(expected, manager.id_to_defs());
1400   CheckUse(expected, &manager, context.module()->IdBound());
1401 }
1402 
TEST(AnalyzeInstDefUse,AddNewInstruction)1403 TEST(AnalyzeInstDefUse, AddNewInstruction) {
1404   const std::string input = "%1 = OpTypeBool";
1405 
1406   // Build module.
1407   std::unique_ptr<IRContext> context =
1408       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
1409   ASSERT_NE(nullptr, context);
1410 
1411   // Analyze the instructions.
1412   DefUseManager manager(context->module());
1413 
1414   Instruction newInst(context.get(), spv::Op::OpConstantTrue, 1, 2, {});
1415   manager.AnalyzeInstDefUse(&newInst);
1416 
1417   InstDefUse expected = {
1418       {
1419           // defs
1420           {1, "%1 = OpTypeBool"},
1421           {2, "%2 = OpConstantTrue %1"},
1422       },
1423       {
1424           // uses
1425           {1, {"%2 = OpConstantTrue %1"}},
1426       },
1427   };
1428 
1429   CheckDef(expected, manager.id_to_defs());
1430   CheckUse(expected, &manager, context->module()->IdBound());
1431 }
1432 
1433 struct KillInstTestCase {
1434   const char* before;
1435   std::unordered_set<uint32_t> indices_for_inst_to_kill;
1436   const char* after;
1437   InstDefUse expected_define_use;
1438 };
1439 
1440 using KillInstTest = ::testing::TestWithParam<KillInstTestCase>;
1441 
TEST_P(KillInstTest,Case)1442 TEST_P(KillInstTest, Case) {
1443   auto tc = GetParam();
1444 
1445   // Build module.
1446   std::unique_ptr<IRContext> context =
1447       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before,
1448                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1449   ASSERT_NE(nullptr, context);
1450 
1451   // Force a re-build of the def-use manager.
1452   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1453   (void)context->get_def_use_mgr();
1454 
1455   // KillInst
1456   context->module()->ForEachInst([&tc, &context](Instruction* inst) {
1457     if (tc.indices_for_inst_to_kill.count(inst->result_id())) {
1458       context->KillInst(inst);
1459     }
1460   });
1461 
1462   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
1463   CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs());
1464   CheckUse(tc.expected_define_use, context->get_def_use_mgr(),
1465            context->module()->IdBound());
1466 }
1467 
1468 // clang-format off
1469 INSTANTIATE_TEST_SUITE_P(
1470     TestCase, KillInstTest,
1471     ::testing::ValuesIn(std::vector<KillInstTestCase>{
1472       // Kill id defining instructions.
1473       {
1474         "%3 = OpTypeVoid "
1475         "%1 = OpTypeFunction %3 "
1476         "%2 = OpFunction %1 None %3 "
1477         "%4 = OpLabel "
1478         "     OpBranch %5 "
1479         "%5 = OpLabel "
1480         "     OpBranch %6 "
1481         "%6 = OpLabel "
1482         "     OpBranch %4 "
1483         "%7 = OpLabel "
1484         "     OpReturn "
1485         "     OpFunctionEnd",
1486         {3, 5, 7},
1487         "%1 = OpTypeFunction %3\n"
1488         "%2 = OpFunction %1 None %3\n"
1489         "%4 = OpLabel\n"
1490         "OpBranch %5\n"
1491         "OpNop\n"
1492         "OpBranch %6\n"
1493         "%6 = OpLabel\n"
1494         "OpBranch %4\n"
1495         "OpNop\n"
1496         "OpReturn\n"
1497         "OpFunctionEnd",
1498         {
1499           // defs
1500           {
1501             {1, "%1 = OpTypeFunction %3"},
1502             {2, "%2 = OpFunction %1 None %3"},
1503             {4, "%4 = OpLabel"},
1504             {6, "%6 = OpLabel"},
1505           },
1506           // uses
1507           {
1508             {1, {"%2 = OpFunction %1 None %3"}},
1509             {4, {"OpBranch %4"}},
1510             {6, {"OpBranch %6"}},
1511           }
1512         }
1513       },
1514       // Kill instructions that do not have result ids.
1515       {
1516         "%3 = OpTypeVoid "
1517         "%1 = OpTypeFunction %3 "
1518         "%2 = OpFunction %1 None %3 "
1519         "%4 = OpLabel "
1520         "     OpBranch %5 "
1521         "%5 = OpLabel "
1522         "     OpBranch %6 "
1523         "%6 = OpLabel "
1524         "     OpBranch %4 "
1525         "%7 = OpLabel "
1526         "     OpReturn "
1527         "     OpFunctionEnd",
1528         {2, 4},
1529         "%3 = OpTypeVoid\n"
1530         "%1 = OpTypeFunction %3\n"
1531              "OpNop\n"
1532              "OpNop\n"
1533              "OpBranch %5\n"
1534         "%5 = OpLabel\n"
1535              "OpBranch %6\n"
1536         "%6 = OpLabel\n"
1537              "OpBranch %4\n"
1538         "%7 = OpLabel\n"
1539              "OpReturn\n"
1540              "OpFunctionEnd",
1541         {
1542           // defs
1543           {
1544             {1, "%1 = OpTypeFunction %3"},
1545             {3, "%3 = OpTypeVoid"},
1546             {5, "%5 = OpLabel"},
1547             {6, "%6 = OpLabel"},
1548             {7, "%7 = OpLabel"},
1549           },
1550           // uses
1551           {
1552             {3, {"%1 = OpTypeFunction %3"}},
1553             {5, {"OpBranch %5"}},
1554             {6, {"OpBranch %6"}},
1555           }
1556         }
1557       },
1558       }));
1559 // clang-format on
1560 
1561 struct GetAnnotationsTestCase {
1562   const char* code;
1563   uint32_t id;
1564   std::vector<std::string> annotations;
1565 };
1566 
1567 using GetAnnotationsTest = ::testing::TestWithParam<GetAnnotationsTestCase>;
1568 
TEST_P(GetAnnotationsTest,Case)1569 TEST_P(GetAnnotationsTest, Case) {
1570   const GetAnnotationsTestCase& tc = GetParam();
1571 
1572   // Build module.
1573   std::unique_ptr<IRContext> context =
1574       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code);
1575   ASSERT_NE(nullptr, context);
1576 
1577   // Get annotations
1578   DefUseManager manager(context->module());
1579   auto insts = manager.GetAnnotations(tc.id);
1580 
1581   // Check
1582   ASSERT_EQ(tc.annotations.size(), insts.size())
1583       << "wrong number of annotation instructions";
1584   auto inst_iter = insts.begin();
1585   for (const std::string& expected_anno_inst : tc.annotations) {
1586     EXPECT_EQ(expected_anno_inst, DisassembleInst(*inst_iter))
1587         << "annotation instruction mismatch";
1588     inst_iter++;
1589   }
1590 }
1591 
1592 // clang-format off
1593 INSTANTIATE_TEST_SUITE_P(
1594     TestCase, GetAnnotationsTest,
1595     ::testing::ValuesIn(std::vector<GetAnnotationsTestCase>{
1596       // empty
1597       {"", 0, {}},
1598       // basic
1599       {
1600         // code
1601         "OpDecorate %1 Block "
1602         "OpDecorate %1 RelaxedPrecision "
1603         "%3 = OpTypeInt 32 0 "
1604         "%1 = OpTypeStruct %3",
1605         // id
1606         1,
1607         // annotations
1608         {
1609           "OpDecorate %1 Block",
1610           "OpDecorate %1 RelaxedPrecision",
1611         },
1612       },
1613       // with debug instructions
1614       {
1615         // code
1616         "OpName %1 \"struct_type\" "
1617         "OpName %3 \"int_type\" "
1618         "OpDecorate %1 Block "
1619         "OpDecorate %1 RelaxedPrecision "
1620         "%3 = OpTypeInt 32 0 "
1621         "%1 = OpTypeStruct %3",
1622         // id
1623         1,
1624         // annotations
1625         {
1626           "OpDecorate %1 Block",
1627           "OpDecorate %1 RelaxedPrecision",
1628         },
1629       },
1630       // no annotations
1631       {
1632         // code
1633         "OpName %1 \"struct_type\" "
1634         "OpName %3 \"int_type\" "
1635         "OpDecorate %1 Block "
1636         "OpDecorate %1 RelaxedPrecision "
1637         "%3 = OpTypeInt 32 0 "
1638         "%1 = OpTypeStruct %3",
1639         // id
1640         3,
1641         // annotations
1642         {},
1643       },
1644       // decoration group
1645       {
1646         // code
1647         "OpDecorate %1 Block "
1648         "OpDecorate %1 RelaxedPrecision "
1649         "%1 = OpDecorationGroup "
1650         "OpGroupDecorate %1 %2 %3 "
1651         "%4 = OpTypeInt 32 0 "
1652         "%2 = OpTypeStruct %4 "
1653         "%3 = OpTypeStruct %4 %4",
1654         // id
1655         3,
1656         // annotations
1657         {
1658           "OpGroupDecorate %1 %2 %3",
1659         },
1660       },
1661       // member decorate
1662       {
1663         // code
1664         "OpMemberDecorate %1 0 RelaxedPrecision "
1665         "%2 = OpTypeInt 32 0 "
1666         "%1 = OpTypeStruct %2 %2",
1667         // id
1668         1,
1669         // annotations
1670         {
1671           "OpMemberDecorate %1 0 RelaxedPrecision",
1672         },
1673       },
1674       }));
1675 
1676 using UpdateUsesTest = PassTest<::testing::Test>;
1677 
TEST_F(UpdateUsesTest,KeepOldUses)1678 TEST_F(UpdateUsesTest, KeepOldUses) {
1679   const std::vector<const char*> text = {
1680       // clang-format off
1681       "OpCapability Shader",
1682       "%1 = OpExtInstImport \"GLSL.std.450\"",
1683       "OpMemoryModel Logical GLSL450",
1684       "OpEntryPoint Vertex %main \"main\"",
1685       "OpName %main \"main\"",
1686       "%void = OpTypeVoid",
1687       "%4 = OpTypeFunction %void",
1688       "%uint = OpTypeInt 32 0",
1689       "%uint_5 = OpConstant %uint 5",
1690       "%25 = OpConstant %uint 25",
1691       "%main = OpFunction %void None %4",
1692       "%8 = OpLabel",
1693       "%9 = OpIMul %uint %uint_5 %uint_5",
1694       "%10 = OpIMul %uint %9 %uint_5",
1695       "OpReturn",
1696       "OpFunctionEnd"
1697       // clang-format on
1698   };
1699 
1700   std::unique_ptr<IRContext> context =
1701       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
1702                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1703   ASSERT_NE(nullptr, context);
1704 
1705   DefUseManager* def_use_mgr = context->get_def_use_mgr();
1706   Instruction* def = def_use_mgr->GetDef(9);
1707   Instruction* use = def_use_mgr->GetDef(10);
1708   def->SetOpcode(spv::Op::OpCopyObject);
1709   def->SetInOperands({{SPV_OPERAND_TYPE_ID, {25}}});
1710   context->UpdateDefUse(def);
1711 
1712   auto scanUser = [&](Instruction* user) { return user != use; };
1713   bool userFound = !def_use_mgr->WhileEachUser(def, scanUser);
1714 
1715   EXPECT_TRUE(userFound);
1716 }
1717 // clang-format on
1718 
1719 }  // namespace
1720 }  // namespace analysis
1721 }  // namespace opt
1722 }  // namespace spvtools
1723