/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that uses the IRBuilder to cast the provided
// llvm::Value* to the default address space. This is useful in particular for
// generating IR for the AMDGPU target, whose kernel variables live in address
// space 5 instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type = llvm::PointerType::getWithSamePointeeType(
        llvm::cast<llvm::PointerType>(arg_type), 0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(&b_, module_, is_nested),
      hlo_module_config_(hlo_module_config) {}

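// Emits an element-wise loop for `hlo` by wiring its operands into element
// generators and delegating the per-element computation to
// GpuElementalIrEmitter.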
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return OkStatus();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return OkStatus();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand),
          llvm_ir::ShapeToIrType(operand->shape(), module_), &b_));
  return OkStatus();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return OkStatus();
}

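// Emits (and caches) an LLVM function for `nested_computation`, then emits a
// call to it. The call passes the operand pointers followed by the output
// pointer, with all pointers cast to the default address space.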
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // Operands are already in the default address space for non-AMDGPU targets.
  // For the AMDGPU target, however, alloca variables must be addrspacecast
  // from addrspace 5 to addrspace 0.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return OkStatus();
}

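// Tries to lower `computation`, applied to *output_address and
// *source_address, to a single hardware atomic instruction (atomic store,
// atomic add, or atomic min/max). Returns true on success; returns false if
// the caller must fall back to the compare-and-swap loop emitted by
// EmitAtomicOperationUsingCAS.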
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source =
      Load(llvm_ir::PrimitiveTypeToIrType(element_type, module_),
           source_address, "source");

  // Just passing along the RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // We special-case only computations with one computing instruction for
    // now. Such a computation has exactly three instructions given that it
    // has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability().IsAtLeast(6);
      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::MaybeAlign(),
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (IsEmittingForAMDGPU() &&
        (element_type == F32)) /* is atomic add supported? */ {
      EmitAMDGPUAtomicAdd(output_address, source);
      return true;
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(
          llvm::AtomicRMWInst::Add, output_address, source, llvm::MaybeAlign(),
          llvm::AtomicOrdering::SequentiallyConsistent, DetermineSyncScope());
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent,
              DetermineSyncScope());
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent,
              DetermineSyncScope());
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//     records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//     operation and stores the result in new_output.
//   3. Calls atomicCAS which implements compare-and-swap as an atomic
//     operation. In particular, atomicCAS reads the value from the memory
//     pointed to by output_address, and compares the value with old_output. If
//     the two values are equal, new_output is written to the same memory
//     location and true is returned to indicate that the atomic operation
//     succeeded. Otherwise, the new value read from the memory is returned. In
//     this case, the new value is copied to old_output, and steps 2. and 3.
//     are repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
// the element type of the binary operation is 32 bits or 64 bits, the integer
// type of the same size is used for the atomicCAS operation. On the other
// hand, if the element type is smaller than 32 bits, int32_t is used for the
// atomicCAS operation. In this case, atomicCAS reads and writes 32 bit values
// from the memory, which is larger than the memory size required by the
// original atomic binary operation. We mask off the last two bits of the
// output_address and use the result as an address to read the 32 bit values
// from the memory. This can avoid out-of-bounds memory accesses if tensor
// buffers are 4 byte aligned and have a size of 4N, an assumption that the
// runtime can guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64_t)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//       atomicCAS(atomic_address, *cas_old_output_address,
//       *cas_new_output_address);
//   } while (!success);
//
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address,
                                              llvm::Type* element_type) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);
  CHECK(output_address_type->isOpaqueOrPointeeTypeMatches(element_type));

  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::AllocaInst* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::AllocaInst* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // Assume the element size (in bits) is a whole number of bytes.
    CHECK_EQ(element_size % 8, 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output =
      Load(atomic_type, atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address->getAllocatedType(),
                        cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emit code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address->getAllocatedType(),
                                     cas_new_output_address, "cas_new_output");

  // If cas_new_output == cas_old_output, we're not asking for anything to
  // change, so we're done here!
  llvm::Value* old_eq_new = ICmpEQ(cas_old_output, cas_new_output);
  llvm::BasicBlock* loop_cas_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_cas", b_.GetInsertBlock()->getParent());
  CondBr(old_eq_new, loop_exit_bb, loop_cas_bb);
  b_.SetInsertPoint(loop_cas_bb);

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value = AtomicCmpXchg(
      atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
      llvm::AtomicOrdering::SequentiallyConsistent,
      llvm::AtomicOrdering::SequentiallyConsistent, DetermineSyncScope());

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return OkStatus();
}

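// Emits an atomic application of the two-parameter `computation`, preferring a
// direct hardware atomic and otherwise falling back to the atomicCAS loop
// described above.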
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address, llvm::Type* element_type) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return OkStatus();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address, element_type);
}

bool IrEmitter::IsEmittingForAMDGPU() const {
  llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
  return target_triple.isAMDGPU();
}

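// Emits an atomic floating-point add for the AMDGPU target, casting the output
// pointer to the global address space when it is not in shared memory so the
// backend can select a global atomic add instruction.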
void IrEmitter::EmitAMDGPUAtomicAdd(llvm::Value* output_address,
                                    llvm::Value* source) {
  CHECK(IsEmittingForAMDGPU());
  auto output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  auto output_ptr =
      (output_address_type->getPointerAddressSpace() != 3)
          ?
          // the compiler will only generate a global_atomic_fadd if the pointer
          // is in global addrspace (1)
          b_.CreateAddrSpaceCast(
              output_address,
              llvm::PointerType::getWithSamePointeeType(output_address_type,
                                                        /*AddressSpace=*/1))
          :
          // adds to shared memory are always atomic.
          output_address;

  AtomicRMW(llvm::AtomicRMWInst::FAdd, output_ptr, source, llvm::MaybeAlign(),
            llvm::AtomicOrdering::SequentiallyConsistent,
            b_.getContext().getOrInsertSyncScopeID("agent"));
}

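// Returns the synchronization scope for atomic operations: the "agent"
// (device-wide) scope on AMDGPU, and the system-wide scope elsewhere.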
llvm::SyncScope::ID IrEmitter::DetermineSyncScope() const {
  return (IsEmittingForAMDGPU())
             ? b_.getContext().getOrInsertSyncScopeID("agent")
             : llvm::SyncScope::System;
}

namespace {
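// Helpers for complex values, which are represented as an LLVM struct of
// {real, imaginary} components.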
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

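// Emits the product of two complex values:
//   (a + bi)(c + di) = (ac - bd) + (ad + bc)i
// and returns the {real, imaginary} pair of the result.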
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return OkStatus();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return OkStatus();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return OkStatus();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          *fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly.  It should "
      "be lowered before IR emission to HLO-soup using BatchNormRewriter.");
}

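// Evaluates `computation` on scalar `parameter_elements` by spilling each
// element to a stack slot and delegating to ComputeNestedElementFromAddrs.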
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  return ComputeNestedElementFromAddrs(computation, parameter_buffers);
}

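// Evaluates `computation` on the values stored at `parameter_elements_addrs`:
// allocates a return buffer (plus per-element allocas for tuple-shaped
// results), calls the nested computation, and loads the returned scalars back
// out of the allocas.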
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElementFromAddrs(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements_addrs) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Type* return_buffer_type =
      llvm_ir::ShapeToIrType(return_shape, module_);
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      return_buffer_type, "return_buffer", &b_);

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_buffer_type,
                                 return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, parameter_elements_addrs, return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    auto alloca = llvm::cast<llvm::AllocaInst>(addr);
    returned_scalars.push_back(Load(alloca->getAllocatedType(), alloca));
  }
  return returned_scalars;
}

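// Builds one IrArray per output of `hlo`: one per tuple element for
// tuple-shaped results, otherwise a single array for the whole result.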
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64_t num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64_t i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

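// Binds each fused parameter of `fusion` to a generator that reads the
// corresponding element of the matching fusion operand.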
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        *fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla