/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace gpu {

IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
                                 const HloComputation& nested_computation,
                                 IrEmitterContext* ir_emitter_context)
    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
      nested_computation_(nested_computation) {}

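// Construction goes through this StatusOr-returning factory rather than the
// constructor alone because emitting the computation's constants can fail.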
StatusOr<std::unique_ptr<IrEmitterNested>> IrEmitterNested::Create(
    const HloModuleConfig& hlo_module_config,
    const HloComputation& nested_computation,
    IrEmitterContext* ir_emitter_context) {
  std::unique_ptr<IrEmitterNested> emitter(new IrEmitterNested(
      hlo_module_config, nested_computation, ir_emitter_context));
  TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation));
  return emitter;
}

// A nested function serves the same purpose on a GPU as a thread-local
// function does on a CPU.
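// The emitted function takes one pointer per parameter buffer plus a final
// pointer to the output buffer, and returns void; the caller allocates all
// of these buffers.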
Status IrEmitterNested::CodegenNestedComputation() {
  std::vector<const HloInstruction*> io_hlos;
  std::vector<llvm::Type*> argument_types;
  std::vector<int64_t> argument_dereferenceable_bytes;
  const auto& params = nested_computation_.parameter_instructions();
  const auto n = params.size() + 1;
  io_hlos.reserve(n - 1);
  argument_types.reserve(n);
  argument_dereferenceable_bytes.reserve(n);
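  // Each parameter is passed by pointer. Record its IR type and byte size so
  // that the corresponding argument can be marked dereferenceable below.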
  for (const HloInstruction* param : params) {
    io_hlos.push_back(param);
    const Shape& param_shape = param->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
    int64_t param_size =
        llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
    argument_dereferenceable_bytes.push_back(param_size);
  }

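  // The root value is returned through one extra out parameter, so reserve an
  // argument slot (and a dereferenceable size) for it as well.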
  const HloInstruction* root = nested_computation_.root_instruction();
  {
    const Shape& root_shape = root->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
    int64_t root_size = llvm_ir::ByteSizeOf(
        root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
    argument_dereferenceable_bytes.push_back(root_size);
  }

  llvm::FunctionType* function_type =
      llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
  llvm::Function* function = llvm::Function::Create(
      function_type,                       // The function type.
      llvm::GlobalValue::InternalLinkage,  // The linkage type.
      ir_emitter_context_->name_uniquer()->GetUniqueName(
          llvm_ir::SanitizeFunctionName(
              nested_computation_.name())),  // The name of the function.
      ir_emitter_context_->llvm_module());   // The parent LLVM module.
  for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
       ++arg_no) {
    int64_t arg_size = argument_dereferenceable_bytes[arg_no];
    if (arg_size > 0) {
      function->addDereferenceableParamAttr(arg_no, arg_size);
    }
  }

  // TODO(b/65380986): Investigate if adding fast math flags for generated
  // kernels makes sense.

  llvm::BasicBlock* entry_bb =
      llvm::BasicBlock::Create(function->getContext(), "entry", function);
  // Emit a "return void" at the end of entry_bb, and set the insert point
  // before that return instruction.
  llvm::ReturnInst* ret_instr =
      llvm::ReturnInst::Create(function->getContext(), entry_bb);
  b_.SetInsertPoint(ret_instr);

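  // Bind the parameters (the I/O HLOs) to the function arguments. Every other
  // instruction, including the root, gets a base pointer into a temporary
  // buffer; the root's buffer is copied into the out parameter in the
  // epilogue below.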
  std::vector<const HloInstruction*> non_io_hlos;
  non_io_hlos.push_back(root);
  for (const auto* hlo : nested_computation_.instructions()) {
    if (hlo->opcode() != HloOpcode::kParameter &&
        hlo != nested_computation_.root_instruction()) {
      non_io_hlos.push_back(hlo);
    }
  }
  bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);

  TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
  b_.SetInsertPoint(ret_instr);

  // Function epilogue: copy the output value back.
  {
    // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
    const HloInstruction* root_instruction =
        nested_computation_.root_instruction();
    llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
    const Shape& return_shape = root_instruction->shape();

    // Last argument is the out parameter.
    llvm::Argument* out_parameter = std::prev(function->arg_end(), 1);

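    // A scalar result is copied back with a single load/store; a tuple result
    // is copied element by element.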
    if (ShapeUtil::IsScalar(return_shape)) {
      llvm::Value* ret_value =
          Load(llvm_ir::ShapeToIrType(return_shape, module_), root_value,
               "load_ret_value");
      Store(ret_value,
            BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"));
    } else {
      CHECK(return_shape.IsTuple());
      llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
      llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
      llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);

      for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
        const Shape& element_shape = return_shape.tuple_shapes(i);
        llvm::Value* destination = llvm_ir::EmitGetTupleElement(
            element_shape,
            /*index=*/i,
            /*alignment=*/1, tuple_ptr, tuple_type, &b_);
        llvm::Value* source = llvm_ir::EmitGetTupleElement(
            element_shape,
            /*index=*/i,
            /*alignment=*/1, root_value,
            llvm_ir::ShapeToIrType(root_instruction->shape(), module_), &b_);
        Store(Load(llvm_ir::ShapeToIrType(element_shape, module_), source),
              destination);
      }
    }
  }
  b_.SetInsertPoint(ret_instr);
  emitted_function_ = function;
  return OkStatus();
}

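// Parameters are bound directly to the function arguments in
// CodegenNestedComputation, so there is nothing to emit for them here.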
Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
  return OkStatus();
}

Status IrEmitterNested::EmitTargetElementLoop(
    const HloInstruction& hlo,
    const llvm_ir::ElementGenerator& element_generator) {
  // For multi-output fusion (MOF) we give the loop emitter an array for every
  // output it should generate.
  if (hlo.shape().IsTuple()) {
    std::vector<llvm_ir::IrArray> target_arrays =
        ConstructIrArrayForOutputs(hlo);
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
    llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
    return OkStatus();
  }
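  // Single-output case: emit the element loop directly into the instruction's
  // own buffer.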
  return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
      .EmitLoop();
}

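// Emits every kConstant in `computation` as a module-level i8 array global.
// Literals that are cheap to materialize in LLVM IR get a real initializer;
// all others are zero-initialized here and their raw bytes are recorded in
// ConstantInfo so that GpuExecutable can populate them at load time.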
Status IrEmitterNested::EmitConstants(const HloComputation& computation) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them external linkage. Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/0,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

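    // Record the constant so GpuExecutable can resolve it by symbol name; if
    // no initializer was emitted above, also stash the literal's bytes so
    // they can be uploaded at load time.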
    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

    if (!should_emit_initializer) {
      auto base = static_cast<const uint8_t*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla