1 // Copyright (c) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/opextinst_forward_ref_fixup_pass.h"
16 
17 #include <string>
18 #include <unordered_set>
19 
20 #include "source/extensions.h"
21 #include "source/opt/ir_context.h"
22 #include "source/opt/module.h"
23 #include "type_manager.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
29 // Returns true if the instruction |inst| has a forward reference to another
30 // debug instruction.
31 // |debug_ids| contains the list of IDs belonging to debug instructions.
32 // |seen_ids| contains the list of IDs already seen.
HasForwardReference(const Instruction & inst,const std::unordered_set<uint32_t> & debug_ids,const std::unordered_set<uint32_t> & seen_ids)33 bool HasForwardReference(const Instruction& inst,
34                          const std::unordered_set<uint32_t>& debug_ids,
35                          const std::unordered_set<uint32_t>& seen_ids) {
36   const uint32_t num_in_operands = inst.NumInOperands();
37   for (uint32_t i = 0; i < num_in_operands; ++i) {
38     const Operand& op = inst.GetInOperand(i);
39     if (!spvIsIdType(op.type)) continue;
40 
41     if (debug_ids.count(op.AsId()) == 0) continue;
42 
43     if (seen_ids.count(op.AsId()) == 0) return true;
44   }
45 
46   return false;
47 }
48 
49 // Replace |inst| opcode with OpExtInstWithForwardRefsKHR or OpExtInst
50 // if required to comply with forward references.
ReplaceOpcodeIfRequired(Instruction & inst,bool hasForwardReferences)51 bool ReplaceOpcodeIfRequired(Instruction& inst, bool hasForwardReferences) {
52   if (hasForwardReferences &&
53       inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
54     inst.SetOpcode(spv::Op::OpExtInstWithForwardRefsKHR);
55   else if (!hasForwardReferences && inst.opcode() != spv::Op::OpExtInst)
56     inst.SetOpcode(spv::Op::OpExtInst);
57   else
58     return false;
59   return true;
60 }
61 
62 // Returns all the result IDs of the instructions in |range|.
gatherResultIds(const IteratorRange<Module::inst_iterator> & range)63 std::unordered_set<uint32_t> gatherResultIds(
64     const IteratorRange<Module::inst_iterator>& range) {
65   std::unordered_set<uint32_t> output;
66   for (const auto& it : range) output.insert(it.result_id());
67   return output;
68 }
69 
70 }  // namespace
71 
Process()72 Pass::Status OpExtInstWithForwardReferenceFixupPass::Process() {
73   std::unordered_set<uint32_t> seen_ids =
74       gatherResultIds(get_module()->ext_inst_imports());
75   std::unordered_set<uint32_t> debug_ids =
76       gatherResultIds(get_module()->ext_inst_debuginfo());
77   for (uint32_t id : seen_ids) debug_ids.insert(id);
78 
79   bool moduleChanged = false;
80   bool hasAtLeastOneForwardReference = false;
81   IRContext* ctx = context();
82   for (Instruction& inst : get_module()->ext_inst_debuginfo()) {
83     if (inst.opcode() != spv::Op::OpExtInst &&
84         inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
85       continue;
86 
87     seen_ids.insert(inst.result_id());
88     bool hasForwardReferences = HasForwardReference(inst, debug_ids, seen_ids);
89     hasAtLeastOneForwardReference |= hasForwardReferences;
90 
91     if (ReplaceOpcodeIfRequired(inst, hasForwardReferences)) {
92       moduleChanged = true;
93       ctx->AnalyzeUses(&inst);
94     }
95   }
96 
97   if (hasAtLeastOneForwardReference !=
98       ctx->get_feature_mgr()->HasExtension(
99           kSPV_KHR_relaxed_extended_instruction)) {
100     if (hasAtLeastOneForwardReference)
101       ctx->AddExtension("SPV_KHR_relaxed_extended_instruction");
102     else
103       ctx->RemoveExtension(Extension::kSPV_KHR_relaxed_extended_instruction);
104     moduleChanged = true;
105   }
106 
107   return moduleChanged ? Status::SuccessWithChange
108                        : Status::SuccessWithoutChange;
109 }
110 
111 }  // namespace opt
112 }  // namespace spvtools
113