/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace gpu {

IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
                                 const HloComputation& nested_computation,
                                 IrEmitterContext* ir_emitter_context)
    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
      nested_computation_(nested_computation) {}

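// Creates the nested emitter and immediately emits LLVM globals for every
// constant in `nested_computation`, so that later codegen can refer to them.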
StatusOr<std::unique_ptr<IrEmitterNested>> IrEmitterNested::Create(
    const HloModuleConfig& hlo_module_config,
    const HloComputation& nested_computation,
    IrEmitterContext* ir_emitter_context) {
  std::unique_ptr<IrEmitterNested> emitter(new IrEmitterNested(
      hlo_module_config, nested_computation, ir_emitter_context));
  TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation));
  return emitter;
}

// A nested function serves the same purpose on the GPU as a thread-local
// function does on the CPU.
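//
// The emitted function returns void and takes one pointer argument for each
// parameter of the computation, plus a trailing out-pointer through which the
// root value is written back:
//   void fn(param0_type* arg0, ..., paramN_type* argN, root_type* out)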
Status IrEmitterNested::CodegenNestedComputation() {
  std::vector<const HloInstruction*> io_hlos;
  std::vector<llvm::Type*> argument_types;
  std::vector<int64_t> argument_dereferenceable_bytes;
  const auto& params = nested_computation_.parameter_instructions();
  const auto n = params.size() + 1;
  io_hlos.reserve(n - 1);
  argument_types.reserve(n);
  argument_dereferenceable_bytes.reserve(n);
  for (const HloInstruction* param : params) {
    io_hlos.push_back(param);
    const Shape& param_shape = param->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
    int64_t param_size =
        llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
    argument_dereferenceable_bytes.push_back(param_size);
  }

  const HloInstruction* root = nested_computation_.root_instruction();
  {
    const Shape& root_shape = root->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
    int64_t root_size = llvm_ir::ByteSizeOf(
        root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
    argument_dereferenceable_bytes.push_back(root_size);
  }

  llvm::FunctionType* function_type =
      llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
  llvm::Function* function = llvm::Function::Create(
      function_type,                       // The function type.
      llvm::GlobalValue::InternalLinkage,  // The linkage type.
      ir_emitter_context_->name_uniquer()->GetUniqueName(
          llvm_ir::SanitizeFunctionName(
              nested_computation_.name())),  // The name of the function.
      ir_emitter_context_->llvm_module());   // The parent LLVM module.
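  // `dereferenceable(N)` tells LLVM that each buffer argument points to at
  // least N readable bytes, which allows loads through these pointers to be
  // hoisted or speculated.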
  for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
       ++arg_no) {
    int64_t arg_size = argument_dereferenceable_bytes[arg_no];
    if (arg_size > 0) {
      function->addDereferenceableParamAttr(arg_no, arg_size);
    }
  }

  // TODO(b/65380986): Investigate if adding fast math flags for generated
  // kernels makes sense.

  llvm::BasicBlock* entry_bb =
      llvm::BasicBlock::Create(function->getContext(), "entry", function);
  // Emit a "return void" at the end of entry_bb, and set the insert point
  // before that return instruction.
  llvm::ReturnInst* ret_instr =
      llvm::ReturnInst::Create(function->getContext(), entry_bb);
  b_.SetInsertPoint(ret_instr);

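  // Bind the parameters (the IO HLOs) and every remaining instruction to base
  // pointers. The root is listed first among the non-IO HLOs; its buffer is
  // copied into the out parameter by the epilogue below.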
  std::vector<const HloInstruction*> non_io_hlos;
  non_io_hlos.push_back(root);
  for (const auto* hlo : nested_computation_.instructions()) {
    if (hlo->opcode() != HloOpcode::kParameter &&
        hlo != nested_computation_.root_instruction()) {
      non_io_hlos.push_back(hlo);
    }
  }
  bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);

  TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
  b_.SetInsertPoint(ret_instr);

  // Function epilogue: copy the output value back.
  {
    // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
    const HloInstruction* root_instruction =
        nested_computation_.root_instruction();
    llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
    const Shape& return_shape = root_instruction->shape();

    // Last argument is the out parameter.
    llvm::Argument* out_parameter = std::prev(function->arg_end(), 1);

    if (ShapeUtil::IsScalar(return_shape)) {
      llvm::Value* ret_value =
          Load(llvm_ir::ShapeToIrType(return_shape, module_), root_value,
               "load_ret_value");
      Store(ret_value,
            BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"));
    } else {
      CHECK(return_shape.IsTuple());
      llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
      llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
      llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);

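      // Copy the tuple element by element: load each element from the root's
      // buffer and store it into the matching slot of the out parameter.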
      for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
        const Shape& element_shape = return_shape.tuple_shapes(i);
        llvm::Value* destination = llvm_ir::EmitGetTupleElement(
            element_shape,
            /*index=*/i,
            /*alignment=*/1, tuple_ptr, tuple_type, &b_);
        llvm::Value* source = llvm_ir::EmitGetTupleElement(
            element_shape,
            /*index=*/i,
            /*alignment=*/1, root_value,
            llvm_ir::ShapeToIrType(root_instruction->shape(), module_), &b_);
        Store(Load(llvm_ir::ShapeToIrType(element_shape, module_), source),
              destination);
      }
    }
  }
  b_.SetInsertPoint(ret_instr);
  emitted_function_ = function;
  return OkStatus();
}

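// Parameters are bound to the function's arguments when
// EmitBasePointersForHlos runs in CodegenNestedComputation, so there is
// nothing left to emit for them here.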
Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
  return OkStatus();
}

Status IrEmitterNested::EmitTargetElementLoop(
    const HloInstruction& hlo,
    const llvm_ir::ElementGenerator& element_generator) {
  // For multi-output fusion (MOF), give the loop emitter an array for every
  // output it should generate, then stitch the results into the tuple buffer.
  if (hlo.shape().IsTuple()) {
    std::vector<llvm_ir::IrArray> target_arrays =
        ConstructIrArrayForOutputs(hlo);
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
    llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
    return OkStatus();
  }
  return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
      .EmitLoop();
}

Status IrEmitterNested::EmitConstants(const HloComputation& computation) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them external linkage. Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/0,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

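    // When the initializer was not emitted into the IR (the global above is
    // only zero-initialized), stash the literal's raw bytes so GpuExecutable
    // can materialize the constant itself at load time.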
    if (!should_emit_initializer) {
      auto base = static_cast<const uint8_t*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla