xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Instructions.h"
23 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
27 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace xla {
32 namespace gpu {
33 
34 using absl::StrAppend;
35 using absl::StrCat;
36 
EmitBasePointersForHlos(absl::Span<const HloInstruction * const> io_hlos,absl::Span<const HloInstruction * const> non_io_hlos)37 void HloToIrBindings::EmitBasePointersForHlos(
38     absl::Span<const HloInstruction* const> io_hlos,
39     absl::Span<const HloInstruction* const> non_io_hlos) {
40   CHECK(is_nested_);
41 
42   // I/O HLOs are bound to the arguments of the current IR function,
43   // *excluding* the output argument, which is added to non-I/O HLOs.
44   // I.e.,
45   //
46   // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
47   llvm::Function* function = b_->GetInsertBlock()->getParent();
48   CHECK_EQ(io_hlos.size() + 1, function->arg_size());
49 
50   // An HLO can have duplicated operands. This data structure remembers which
51   // operand HLOs are already bound to avoid rebinding the same HLO.
52   absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
53   auto arg_iter = function->arg_begin();
54   for (const HloInstruction* io_hlo : io_hlos) {
55     CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
56           !absl::c_count(non_io_hlos, io_hlo))
57         << "IO HLOs and non-IO HLOs should be disjoint";
58     if (!already_bound_for_this_function.contains(io_hlo)) {
59       BindHloToIrValue(*io_hlo, &*arg_iter);
60       already_bound_for_this_function.insert(io_hlo);
61     }
62     ++arg_iter;
63   }
64 
65   // Name and skip the output parameter.
66   arg_iter->setName("output_arg");
67   ++arg_iter;
68 
69   for (const HloInstruction* non_io_hlo : non_io_hlos) {
70     if (already_bound_for_this_function.contains(non_io_hlo)) {
71       continue;
72     }
73     already_bound_for_this_function.insert(non_io_hlo);
74 
75     if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
76       continue;
77     }
78 
79     ShapeUtil::ForEachSubshape(
80         non_io_hlo->shape(),
81         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
82           if (non_io_hlo->opcode() == HloOpcode::kConstant) {
83             llvm::Value* global_for_constant = module_->getGlobalVariable(
84                 llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
85             CHECK(global_for_constant)
86                 << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
87             BindHloToIrValue(*non_io_hlo, global_for_constant);
88           } else {
89             llvm::Type* pointee_type =
90                 llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
91             BindHloToIrValue(*non_io_hlo,
92                              llvm_ir::EmitAllocaAtFunctionEntry(
93                                  pointee_type, /*name=*/"", b_),
94                              index);
95           }
96         });
97   }
98 }
99 
EmitGetTupleElement(const HloInstruction * gte,llvm::Value * base_ptr)100 llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
101                                                   llvm::Value* base_ptr) {
102   // TODO(b/26344050): tighten the alignment based on the real element type.
103   if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
104     return llvm_ir::EmitGetTupleElement(
105         gte->shape(), gte->tuple_index(), /*alignment=*/1,
106         GetTypedIrValue(*gte->operand(0), {}, base_ptr),
107         llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_);
108   }
109   return llvm_ir::EmitGetTupleElement(
110       gte->shape(), gte->tuple_index(), /*alignment=*/1,
111       EmitGetTupleElement(gte->operand(0), base_ptr),
112       llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_);
113 }
114 
115 // Returns true if `value` has a name that should not be changed.
HasMeaningfulName(llvm::Value * value)116 static bool HasMeaningfulName(llvm::Value* value) {
117   if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
118     return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
119   }
120   return false;
121 }
122 
CastToTypedValue(const Shape & shape,llvm::Value * ir_value,llvm::IRBuilder<> * b)123 llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
124                               llvm::IRBuilder<>* b) {
125   llvm::Type* pointee_type =
126       llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());
127 
128   llvm::Type* dest_type = pointee_type->getPointerTo();
129 
130   llvm::Value* typed_ir_value;
131   if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
132     typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
133         llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
134   } else {
135     typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast(
136         ir_value, pointee_type->getPointerTo());
137   }
138   return typed_ir_value;
139 }
140 
GetTypedIrValue(const HloInstruction & hlo,ShapeIndexView shape_index,llvm::Value * ir_value)141 llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
142                                               ShapeIndexView shape_index,
143                                               llvm::Value* ir_value) {
144   auto typed_ir_value = CastToTypedValue(
145       ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
146   if (!HasMeaningfulName(ir_value)) {
147     ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
148   }
149   if (!HasMeaningfulName(typed_ir_value)) {
150     typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
151   }
152   return typed_ir_value;
153 }
154 
BindHloToIrValue(const HloInstruction & hlo,llvm::Value * ir_value,ShapeIndexView shape_index)155 void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
156                                        llvm::Value* ir_value,
157                                        ShapeIndexView shape_index) {
158   VLOG(2) << "Binding " << hlo.ToString();
159 
160   const Shape& hlo_shape = hlo.shape();
161   llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);
162 
163   if (!BoundToIrValue(hlo)) {
164     // Set the root of ShapeTree first before assigning the element ir value.
165     InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
166   }
167   *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
168 }
169 
GetIrArray(const HloInstruction & hlo,const HloInstruction & consumer,const ShapeIndex & shape_index)170 llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
171                                              const HloInstruction& consumer,
172                                              const ShapeIndex& shape_index) {
173   CHECK(is_nested_)
174       << "IrEmitterUnnested should instead use LMHLO to get the IrArray";
175 
176   llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
177   Shape new_shape = ShapeUtil::GetSubshape(hlo.shape(), shape_index);
178   llvm::Type* pointee_type = llvm_ir::ShapeToIrType(new_shape, module_);
179   CHECK_NE(base_ptr, nullptr)
180       << "Buffer not assigned for shape_index " << shape_index.ToString()
181       << " of " << hlo.ToString();
182   llvm_ir::IrArray ir_array(base_ptr, pointee_type, new_shape);
183 
184   return ir_array;
185 }
186 
UnbindAllLocalIrValues()187 void HloToIrBindings::UnbindAllLocalIrValues() {
188   std::vector<const HloInstruction*> hlos_to_unbind;
189   for (auto& key_value : base_ptrs_) {
190     if (!llvm::isa<llvm::GlobalVariable>(
191             (key_value.second.element({}))->stripPointerCasts())) {
192       hlos_to_unbind.push_back(key_value.first);
193     }
194   }
195   for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
196     VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
197     base_ptrs_.erase(hlo_to_unbind);
198   }
199 }
200 
// Renders a human-readable dump of every current HLO -> IR-value binding,
// grouped by computation, for debugging.
std::string HloToIrBindings::ToString() const {
  std::string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in topological order, and print
  // out the base pointers we have in each computation in topological order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      // Skip instructions with no recorded binding.
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      // Emit the computation header lazily, only once we know it contains at
      // least one bound instruction.
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      // Non-tuple shapes have exactly one entry; print it inline on the same
      // line as the instruction.
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      // Tuple shapes: one indented line per shape index; unbound elements
      // print as "null".
      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}
249 
250 }  // namespace gpu
251 }  // namespace xla
252