/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using absl::StrCat;

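// Binds the HLOs of a nested computation to IR values: I/O HLOs are bound to
// the arguments of the current IR function, constants to their module-level
// globals, and all remaining non-I/O HLOs to allocas at the function entry.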
void HloToIrBindings::EmitBasePointersForHlos(
    absl::Span<const HloInstruction* const> io_hlos,
    absl::Span<const HloInstruction* const> non_io_hlos) {
  CHECK(is_nested_);

  // I/O HLOs are bound to the arguments of the current IR function,
  // *excluding* the output argument, which is added to non-I/O HLOs.
  // I.e.,
  //
  //   void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound to avoid rebinding the same HLO.
  absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
          !absl::c_count(non_io_hlos, io_hlo))
        << "IO HLOs and non-IO HLOs should be disjoint";
    if (!already_bound_for_this_function.contains(io_hlo)) {
      BindHloToIrValue(*io_hlo, &*arg_iter);
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  // Name and skip the output parameter.
  arg_iter->setName("output_arg");
  ++arg_iter;

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.contains(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      continue;
    }

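    // Bind each subshape: constants are bound to the module-level global that
    // holds their literal, while every other non-I/O HLO gets a fresh alloca
    // emitted at the function entry.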
    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          if (non_io_hlo->opcode() == HloOpcode::kConstant) {
            llvm::Value* global_for_constant = module_->getGlobalVariable(
                llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
            CHECK(global_for_constant)
                << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
            BindHloToIrValue(*non_io_hlo, global_for_constant);
          } else {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo,
                             llvm_ir::EmitAllocaAtFunctionEntry(
                                 pointee_type, /*name=*/"", b_),
                             index);
          }
        });
  }
}

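// Emits the address of the tuple element addressed by `gte`, loading the
// element pointer out of the tuple at `base_ptr` and recursing through chains
// of nested kGetTupleElement operands until a non-GTE base is reached.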
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr),
        llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr),
      llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_);
}

// Returns true if `value` has a name that should not be changed.
static bool HasMeaningfulName(llvm::Value* value) {
  if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
    return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
  }
  return false;
}

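// Returns `ir_value` cast to a pointer to the IR type corresponding to
// `shape`, using a bitcast or address-space cast as needed. For globals the
// cast is folded into a constant expression, so the result stays a constant.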
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
                              llvm::IRBuilder<>* b) {
  llvm::Type* pointee_type =
      llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());

  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value =
        b->CreatePointerBitCastOrAddrSpaceCast(ir_value, dest_type);
  }
  return typed_ir_value;
}

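// Returns `ir_value` cast to the IR type of `hlo`'s subshape at `shape_index`,
// naming both the raw and the typed value after the HLO unless they already
// carry a meaningful (externally visible) name.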
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              ShapeIndexView shape_index,
                                              llvm::Value* ir_value) {
  auto typed_ir_value = CastToTypedValue(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
  if (!HasMeaningfulName(ir_value)) {
    ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
  }
  if (!HasMeaningfulName(typed_ir_value)) {
    typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
  }
  return typed_ir_value;
}

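// Records that the subshape of `hlo` at `shape_index` lives at `ir_value`,
// creating the ShapeTree of base pointers for the HLO on first use.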
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       ShapeIndexView shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Initialize the ShapeTree for this HLO before assigning the element IR
    // value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

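// Returns an IrArray view of the buffer bound to the subshape of `hlo` at
// `shape_index`. Only valid for nested computations; unnested emission gets
// its IrArrays from LMHLO instead (see the CHECK below).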
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  CHECK(is_nested_)
      << "IrEmitterUnnested should instead use LMHLO to get the IrArray";

  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  Shape new_shape = ShapeUtil::GetSubshape(hlo.shape(), shape_index);
  llvm::Type* pointee_type = llvm_ir::ShapeToIrType(new_shape, module_);
  llvm_ir::IrArray ir_array(base_ptr, pointee_type, new_shape);

  return ir_array;
}

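// Drops every binding whose value is function-local (e.g. an alloca or a
// function argument); bindings to module-level globals, such as constants,
// remain valid across functions and are kept.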
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

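// Dumps the current bindings for debugging: one section per computation, with
// one line per bound instruction (tuple-shaped instructions additionally get
// one line per shape index).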
std::string HloToIrBindings::ToString() const {
  std::string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in post order (callees before
  // callers), and print the base pointers we have bound in each computation,
  // also in instruction post order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla