1 // Copyright (c) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/opextinst_forward_ref_fixup_pass.h"
16
17 #include <string>
18 #include <unordered_set>
19
20 #include "source/extensions.h"
21 #include "source/opt/ir_context.h"
22 #include "source/opt/module.h"
23 #include "type_manager.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
29 // Returns true if the instruction |inst| has a forward reference to another
30 // debug instruction.
31 // |debug_ids| contains the list of IDs belonging to debug instructions.
32 // |seen_ids| contains the list of IDs already seen.
HasForwardReference(const Instruction & inst,const std::unordered_set<uint32_t> & debug_ids,const std::unordered_set<uint32_t> & seen_ids)33 bool HasForwardReference(const Instruction& inst,
34 const std::unordered_set<uint32_t>& debug_ids,
35 const std::unordered_set<uint32_t>& seen_ids) {
36 const uint32_t num_in_operands = inst.NumInOperands();
37 for (uint32_t i = 0; i < num_in_operands; ++i) {
38 const Operand& op = inst.GetInOperand(i);
39 if (!spvIsIdType(op.type)) continue;
40
41 if (debug_ids.count(op.AsId()) == 0) continue;
42
43 if (seen_ids.count(op.AsId()) == 0) return true;
44 }
45
46 return false;
47 }
48
49 // Replace |inst| opcode with OpExtInstWithForwardRefsKHR or OpExtInst
50 // if required to comply with forward references.
ReplaceOpcodeIfRequired(Instruction & inst,bool hasForwardReferences)51 bool ReplaceOpcodeIfRequired(Instruction& inst, bool hasForwardReferences) {
52 if (hasForwardReferences &&
53 inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
54 inst.SetOpcode(spv::Op::OpExtInstWithForwardRefsKHR);
55 else if (!hasForwardReferences && inst.opcode() != spv::Op::OpExtInst)
56 inst.SetOpcode(spv::Op::OpExtInst);
57 else
58 return false;
59 return true;
60 }
61
62 // Returns all the result IDs of the instructions in |range|.
gatherResultIds(const IteratorRange<Module::inst_iterator> & range)63 std::unordered_set<uint32_t> gatherResultIds(
64 const IteratorRange<Module::inst_iterator>& range) {
65 std::unordered_set<uint32_t> output;
66 for (const auto& it : range) output.insert(it.result_id());
67 return output;
68 }
69
70 } // namespace
71
Process()72 Pass::Status OpExtInstWithForwardReferenceFixupPass::Process() {
73 std::unordered_set<uint32_t> seen_ids =
74 gatherResultIds(get_module()->ext_inst_imports());
75 std::unordered_set<uint32_t> debug_ids =
76 gatherResultIds(get_module()->ext_inst_debuginfo());
77 for (uint32_t id : seen_ids) debug_ids.insert(id);
78
79 bool moduleChanged = false;
80 bool hasAtLeastOneForwardReference = false;
81 IRContext* ctx = context();
82 for (Instruction& inst : get_module()->ext_inst_debuginfo()) {
83 if (inst.opcode() != spv::Op::OpExtInst &&
84 inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
85 continue;
86
87 seen_ids.insert(inst.result_id());
88 bool hasForwardReferences = HasForwardReference(inst, debug_ids, seen_ids);
89 hasAtLeastOneForwardReference |= hasForwardReferences;
90
91 if (ReplaceOpcodeIfRequired(inst, hasForwardReferences)) {
92 moduleChanged = true;
93 ctx->AnalyzeUses(&inst);
94 }
95 }
96
97 if (hasAtLeastOneForwardReference !=
98 ctx->get_feature_mgr()->HasExtension(
99 kSPV_KHR_relaxed_extended_instruction)) {
100 if (hasAtLeastOneForwardReference)
101 ctx->AddExtension("SPV_KHR_relaxed_extended_instruction");
102 else
103 ctx->RemoveExtension(Extension::kSPV_KHR_relaxed_extended_instruction);
104 moduleChanged = true;
105 }
106
107 return moduleChanged ? Status::SuccessWithChange
108 : Status::SuccessWithoutChange;
109 }
110
111 } // namespace opt
112 } // namespace spvtools
113