/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that casts the provided llvm::Value* to the default
// address space using the given IRBuilder. This is useful in particular when
// generating IR for the AMDGPU target, whose kernel variables live in address
// space 5 instead of the default address space.
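// As an illustrative example (exact IR depends on the LLVM pointer
// representation in use), an AMDGPU alloca in the private address space would
// be rewritten roughly as:
//   %arg.cast = addrspacecast ptr addrspace(5) %arg to ptr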
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type = llvm::PointerType::getWithSamePointeeType(
        llvm::cast<llvm::PointerType>(arg_type), 0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(&b_, module_, is_nested),
      hlo_module_config_(hlo_module_config) {}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return OkStatus();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return OkStatus();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand),
          llvm_ir::ShapeToIrType(operand->shape(), module_), &b_));
  return OkStatus();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return OkStatus();
}

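// Emits (and caches) an LLVM function for `nested_computation`, then calls it
// with `operands` followed by `output` as the final argument; the emitted
// function writes its result through `output`.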
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // Operands are in the default address space for non-AMDGPU targets. For the
  // AMDGPU target, however, alloca variables must be addrspacecast from
  // addrspace 5 to addrspace 0.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return OkStatus();
}

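// Attempts to lower the two-parameter `computation` to a single hardware
// atomic instruction on `output_address`: an atomic store when the computation
// just forwards its second parameter, atomicAdd for addition, and
// atomicMax/atomicMin for integral maximum/minimum. Returns true if such an
// instruction was emitted; otherwise returns false and the caller falls back
// to the compare-and-swap loop below.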
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source =
      Load(llvm_ir::PrimitiveTypeToIrType(element_type, module_),
           source_address, "source");

  // Just passing along RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // For now we special-case only computations with a single computing
    // instruction. Such a computation has exactly three instructions, given
    // that it has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability().IsAtLeast(6);
      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::MaybeAlign(),
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (IsEmittingForAMDGPU() &&
        (element_type == F32)) /* is atomic add supported? */ {
      EmitAMDGPUAtomicAdd(output_address, source);
      return true;
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(
          llvm::AtomicRMWInst::Add, output_address, source, llvm::MaybeAlign(),
          llvm::AtomicOrdering::SequentiallyConsistent, DetermineSyncScope());
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent,
              DetermineSyncScope());
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent,
              DetermineSyncScope());
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//      records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//      operation and stores the result in new_output.
//   3. Calls atomicCAS, which implements compare-and-swap as an atomic
//      operation. In particular, atomicCAS reads the value from the memory
//      pointed to by output_address and compares it with old_output. If the
//      two values are equal, new_output is written to the same memory location
//      and true is returned to indicate that the atomic operation succeeded.
//      Otherwise, the new value read from the memory is returned. In this
//      case, the new value is copied to old_output, and steps 2. and 3. are
//      repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32-bit and 64-bit integers. If
// the element type of the binary operation is 32 or 64 bits wide, the integer
// type of the same size is used for the atomicCAS operation. If the element
// type is narrower than 32 bits, int32_t is used for the atomicCAS operation.
// In this case, atomicCAS reads and writes a 32-bit value from memory, which
// is larger than the memory size required by the original atomic binary
// operation. We mask off the last two bits of output_address and use the
// result as an address to read the 32-bit value from memory. This avoids
// out-of-bounds memory accesses as long as tensor buffers are 4-byte aligned
// and have a size that is a multiple of 4, an assumption that the runtime can
// guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64_t)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//         atomicCAS(atomic_address, *cas_old_output_address,
//                   *cas_new_output_address);
//   } while (!success);
//
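// As an illustrative example (not emitted code): for a 16-bit element whose
// output_address ends in binary ...10,
//   atomic_address     = output_address & -4;   // low bits become ...00
//   offset             = output_address & 3;    // == 2
//   new_output_address = cas_new_output_address + 2;
// so the element occupies bytes [2, 4) of the 32-bit word that atomicCAS
// reads, compares, and writes.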
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address,
                                              llvm::Type* element_type) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);
  CHECK(output_address_type->isOpaqueOrPointeeTypeMatches(element_type));

  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::AllocaInst* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::AllocaInst* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // Assume the element size is an integer number of bytes.
    CHECK_EQ((element_size % sizeof(char)), 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output =
      Load(atomic_type, atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address->getAllocatedType(),
                        cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address->getAllocatedType(),
                                     cas_new_output_address, "cas_new_output");

  // If cas_new_output == cas_old_output, we're not asking for anything to
  // change, so we're done here!
  llvm::Value* old_eq_new = ICmpEQ(cas_old_output, cas_new_output);
  llvm::BasicBlock* loop_cas_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_cas", b_.GetInsertBlock()->getParent());
  CondBr(old_eq_new, loop_exit_bb, loop_cas_bb);
  b_.SetInsertPoint(loop_cas_bb);

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value = AtomicCmpXchg(
      atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
      llvm::AtomicOrdering::SequentiallyConsistent,
      llvm::AtomicOrdering::SequentiallyConsistent, DetermineSyncScope());

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return OkStatus();
}

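// Emits `*output_address = computation(*output_address, *source_address)` as
// an atomic update: first tries a single hardware atomic instruction via
// MaybeEmitDirectAtomicOperation, and otherwise falls back to the atomicCAS
// loop above.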
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address, llvm::Type* element_type) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return OkStatus();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address, element_type);
}

bool IrEmitter::IsEmittingForAMDGPU() const {
  llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
  return target_triple.isAMDGPU();
}

void IrEmitter::EmitAMDGPUAtomicAdd(llvm::Value* output_address,
                                    llvm::Value* source) {
  CHECK(IsEmittingForAMDGPU());
  auto output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  auto output_ptr =
      (output_address_type->getPointerAddressSpace() != 3)
          ?
          // the compiler will only generate a global_atomic_fadd if the
          // pointer is in global addrspace (1)
          b_.CreateAddrSpaceCast(
              output_address,
              llvm::PointerType::getWithSamePointeeType(output_address_type,
                                                        /*AddressSpace=*/1))
          :
          // adds to shared memory are always atomic.
          output_address;

  AtomicRMW(llvm::AtomicRMWInst::FAdd, output_ptr, source, llvm::MaybeAlign(),
            llvm::AtomicOrdering::SequentiallyConsistent,
            b_.getContext().getOrInsertSyncScopeID("agent"));
}

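// Returns the synchronization scope for atomic operations: on AMDGPU the
// device-wide "agent" scope (updates only need to be visible to other threads
// on the same GPU), and the default system scope otherwise.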
llvm::SyncScope::ID IrEmitter::DetermineSyncScope() const {
  return (IsEmittingForAMDGPU())
             ? b_.getContext().getOrInsertSyncScopeID("agent")
             : llvm::SyncScope::System;
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

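// Computes the complex product (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
// and returns its {real, imaginary} parts.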
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return OkStatus();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return OkStatus();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return OkStatus();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          *fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly. It should "
      "be lowered before IR emission to HLO-soup using BatchNormRewriter.");
}

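// Spills each scalar parameter element to a stack slot and delegates to
// ComputeNestedElementFromAddrs, which expects parameter addresses.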
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  return ComputeNestedElementFromAddrs(computation, parameter_buffers);
}

StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElementFromAddrs(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements_addrs) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Type* return_buffer_type =
      llvm_ir::ShapeToIrType(return_shape, module_);
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      return_buffer_type, "return_buffer", &b_);

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_buffer_type,
                                 return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, parameter_elements_addrs, return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    auto alloca = llvm::cast<llvm::AllocaInst>(addr);
    returned_scalars.push_back(Load(alloca->getAllocatedType(), alloca));
  }
  return returned_scalars;
}

std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64_t num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64_t i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        *fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla