/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
}  // namespace

namespace cpu {
IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context,
                     const HloModule& hlo_module,
                     const BufferAssignment& assignment,
                     llvm::Module* llvm_module,
                     absl::flat_hash_map<const HloInstruction*, int64_t>
                         instruction_to_profile_idx,
                     absl::flat_hash_map<const HloComputation*, int64_t>
                         computation_to_profile_idx,
                     absl::flat_hash_map<const HloComputation*, bool>
                         computation_transitively_contains_custom_call,
                     const TargetMachineFeatures* target_machine_features,
                     bool emit_code_for_msan)
    : assignment_(assignment),
      module_(llvm_module),
      arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
      b_(llvm_module->getContext()),
      mlir_context_(mlir_context),
      instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
      computation_to_profile_idx_(std::move(computation_to_profile_idx)),
      computation_transitively_contains_custom_call_(
          std::move(computation_transitively_contains_custom_call)),
      alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
      hlo_module_config_(hlo_module.config()),
      is_top_level_computation_(false),
      target_machine_features_(*target_machine_features),
      emit_code_for_msan_(emit_code_for_msan) {
  b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
  Status s = GatherComputationsByAllocationType(
      &hlo_module, &thread_local_computations_, &global_computations_);
  absl::c_sort(thread_local_computations_);
  absl::c_sort(global_computations_);
  TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}

void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
  llvm::Argument* out_parameter = compute_function_->result_arg();
  llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
  const Shape& return_shape = computation->root_instruction()->shape();

  if (ShapeUtil::IsScalar(return_shape)) {
    llvm::Value* ret_value =
        Load(root_value.GetBasePointeeType(), root_value.GetBasePointer(),
             "load_ret_value");
    Store(ret_value,
          BitCast(out_parameter, root_value.GetBasePointer()->getType()));
  } else {
    CHECK(return_shape.IsTuple());

    llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
    llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
    llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);

    for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
      const Shape& element_shape = return_shape.tuple_shapes(i);
      llvm::Value* destination = llvm_ir::EmitGetTupleElement(
          element_shape,
          /*index=*/i,
          /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
          tuple_type, &b_);

      llvm::Value* source = llvm_ir::EmitGetTupleElement(
          element_shape,
          /*index=*/i,
          /*alignment=*/MinimumAlignmentForShape(element_shape),
          root_value.GetBasePointer(), root_value.GetBasePointeeType(), &b_);

      Store(Load(IrShapeType(element_shape), source), destination);
    }
  }
}

StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    HloComputation* computation, const std::string& function_name_prefix,
    bool is_top_level_computation,
    absl::Span<HloInstruction* const> instruction_order,
    bool allow_reassociation) {
  std::string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
  is_top_level_computation_ = is_top_level_computation;
  allow_reassociation_ = allow_reassociation;
  num_dynamic_loop_bounds_ = 0;
  if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
    num_dynamic_loop_bounds_ =
        computation->root_instruction()->outer_dimension_partitions().size();
  }

  if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
    TF_ASSIGN_OR_RETURN(
        computation_root_allocation_,
        assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
  }

  bool has_thread_local_param = false;
  for (const HloInstruction* param : computation->parameter_instructions()) {
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
                        assignment_.GetUniqueTopLevelSlice(param));
    has_thread_local_param |= param_slice.allocation()->is_thread_local();
    computation_parameter_allocations_[param_slice.allocation()->index()] =
        param->parameter_number();
  }

  InitializeIrFunction(function_name);
  // The rdtscp instruction is x86-specific.  We fall back to LLVM's generic
  // readcyclecounter intrinsic if it is unavailable.
  bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
                    arch_type_ == llvm::Triple::ArchType::x86_64;
  profiling_state_ = ProfilingState(use_rdtscp);

  tracing_state_.set_enabled(
      computation->parent()->config().cpu_traceme_enabled());

  llvm::IRBuilderBase::FastMathFlagGuard guard(*builder());
  llvm::FastMathFlags flags = builder()->getFastMathFlags();
  flags.setAllowReassoc(flags.allowReassoc() || allow_reassociation);
  builder()->setFastMathFlags(flags);

  TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
  llvm::Function* ir_function = compute_function_->function();
  InsertOrDie(&emitted_functions_,
              ComputationToEmit{computation, allow_reassociation}, ir_function);
  // Delete 'compute_function', finalizing 'ir_function' and restoring caller
  // IR insert point.

  // Function epilogue: copy the result either into the return register or
  // into the buffers that the return register points to. If we have an
  // allocation for the result, we can key on whether it is thread_local. If
  // the result is a constant, we use the function parameters' allocations to
  // identify a thread_local function.
  const BufferAllocation* root_allocation =
      computation_root_allocation_.allocation();
  if (root_allocation &&
      (root_allocation->is_thread_local() ||
       (root_allocation->is_constant() && has_thread_local_param))) {
    EmitThreadLocalFunctionEpilogue(computation);
  }

  // Destructor for compute_function_ terminates the LLVM function definition.
  compute_function_.reset();
  computation_root_allocation_ = BufferAllocation::Slice();
  computation_parameter_allocations_.clear();
  return ir_function;
}

void IrEmitter::InitializeIrFunction(const std::string& function_name) {
  // Functions with local linkage get an inlining bonus.  Because we know
  // a priori that embedded functions (non-entry functions) will not have
  // their names resolved, we give them local linkage.
  llvm::Function::LinkageTypes linkage =
      is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
                                : llvm::GlobalValue::InternalLinkage;
  // Create and initialize a new IrFunction.
  compute_function_.reset(new IrFunction(function_name, linkage,
                                         hlo_module_config_, module_, &b_,
                                         num_dynamic_loop_bounds_));
}

IrEmitter::~IrEmitter() {}

Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
  emitted_value_[bitcast] =
      BitCast(GetEmittedValueFor(bitcast->operand(0)),
              IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
  return OkStatus();
}

llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
  llvm::Constant* initializer =
      llvm_ir::ConvertLiteralToIrConstant(literal, module_);
  llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
      /*Module=*/*module_,
      /*Type=*/initializer->getType(),
      /*isConstant=*/true,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/initializer,
      /*Name=*/"");
  result_global->setAlignment(
      llvm::Align(MinimumAlignmentForShape(literal.shape())));
  result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
  return llvm::ConstantExpr::getBitCast(
      result_global, IrShapeType(literal.shape())->getPointerTo());
}

Status IrEmitter::EmitConstantGlobals() {
  for (const BufferAllocation& allocation : assignment_.Allocations()) {
    if (!allocation.is_constant()) {
      continue;
    }

    const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
    llvm::Constant* global_for_const;
    auto it = emitted_literals_.find(&literal);
    if (it != emitted_literals_.end()) {
      global_for_const = it->second;
    } else {
      global_for_const = EmitGlobalForLiteral(literal);
      InsertOrDie(&emitted_literals_, &literal, global_for_const);
    }

    InsertOrDie(&constant_buffer_to_global_, allocation.index(),
                global_for_const);
  }

  return OkStatus();
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  VLOG(2) << "HandleConstant: " << constant->ToString();
  // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
  // of the constant.
  return EmitTargetAddressForOp(constant);
}

Status IrEmitter::HandleCopy(HloInstruction* copy) {
  if (copy->shape().IsTuple() ||
      (copy->shape().IsArray() &&
       LayoutUtil::Equal(copy->operand(0)->shape().layout(),
                         copy->shape().layout()))) {
    // If the layouts are equal, this is just a memcpy. kCopy shallow-copies a
    // tuple, so for tuples we just memcpy the top-level buffer.
    TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
    return EmitMemcpy(*(copy->operand(0)), *copy);
  } else if (copy->shape().IsArray()) {
    // Use the elemental emitter for array shapes.
    return DefaultAction(copy);
  }
  return Unimplemented("unsupported operand type %s for copy instruction",
                       PrimitiveType_Name(copy->shape().element_type()));
}

// Calculate the alignment of a buffer allocated for a given primitive type.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
  int64_t byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  DCHECK_GE(byte_size, 0);
  // The largest scalar is a complex128, so we don't need to worry about the
  // int64_t->int truncation here.
  DCHECK_LE(byte_size, 16);

  // Allocations may be 8-byte aligned if part of a small block.
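  // For example, an f32 (4 bytes) gets 4-byte alignment, while a complex128
  // (16 bytes) is capped at 8-byte alignment.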
  return std::min(int64_t{8}, byte_size);
}

int64_t IrEmitter::ByteSizeOf(const Shape& shape) const {
  return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
}

// Calculate the alignment of a buffer allocated for a given shape.
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
  if (ShapeUtil::IsScalar(shape)) {
    return MinimumAlignmentForPrimitiveType(shape.element_type());
  }

  int64_t buffer_size = ByteSizeOf(shape);
  DCHECK_GE(buffer_size, 0);
  DCHECK_LE(buffer_size, SIZE_MAX);

  return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               const Shape& shape) {
  int alignment = MinimumAlignmentForShape(shape);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               int64_t buffer_size) {
  int alignment =
      target_machine_features_.minimum_alignment_for_allocation(buffer_size);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     const Shape& shape) {
  AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     int64_t buffer_size) {
  if (buffer_size > 0) {
    llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
  }
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  // A tuple is an array of pointers, one for each operand. Each pointer points
  // to the output buffer of its corresponding operand. A GetTupleElement
  // instruction forwards a pointer to the tuple element buffer at the given
  // index.
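  // For example, a (f32[8], s32[]) tuple is represented as a two-element array
  // of pointers, where element 0 points at the f32[8] buffer and element 1 at
  // the s32[] buffer; GetTupleElement with tuple_index 1 forwards the second
  // pointer without copying any data.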
  const HloInstruction* operand = get_tuple_element->operand(0);
  const Shape& shape = get_tuple_element->shape();
  emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
      shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
      GetEmittedValueFor(operand), IrShapeType(operand->shape()), &b_);
  return OkStatus();
}

Status IrEmitter::HandleSelect(HloInstruction* select) {
  auto pred = select->operand(0);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  return DefaultAction(select);
}

Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  VLOG(2) << "HandleInfeed: " << infeed->ToString();

  // The infeed operation produces a two-element tuple containing data and a
  // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
  const Shape& data_shape = infeed->infeed_shape();
  DCHECK(ShapeUtil::Equal(data_shape,
                          ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));

  // Write the tuple index table.
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
                      assignment_.GetUniqueSlice(infeed, {0}));
  llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
  llvm::Type* data_type = IrShapeType(data_shape);
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
                      assignment_.GetUniqueSlice(infeed, {1}));
  llvm::Value* token_address = EmitBufferPointer(
      token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
  llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);

  if (data_shape.IsTuple()) {
    TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));

    // For a tuple, we first copy each of the internal elements to
    // their corresponding target locations. We then construct the
    // tuple outer buffer containing pointers to the internal
    // elements.
    std::vector<llvm::Value*> tuple_element_addresses;
    for (int i = 0; i < data_shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
                          assignment_.GetUniqueSlice(infeed, {0, i}));

      const Shape& tuple_element_shape =
          ShapeUtil::GetTupleElementShape(data_shape, i);

      // Only the outer tuple buffer's target address is obtained from
      // GetEmittedValueFor, to handle the case when Infeed is the root
      // instruction. Target addresses for internal elements can be obtained
      // from EmitBufferPointer.
      llvm::Value* tuple_element_address =
          EmitBufferPointer(buffer, tuple_element_shape);

      TF_RETURN_IF_ERROR(EmitXfeedTransfer(
          XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));

      tuple_element_addresses.push_back(tuple_element_address);
    }

    llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_type, data_shape),
                       tuple_element_addresses, &b_);
  } else {
    TF_RETURN_IF_ERROR(
        EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
  }

  return OkStatus();
}

Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                                    llvm::Value* program_buffer_address) {
  int64_t length = ByteSizeOf(shape);
  if (length < 0 || length > std::numeric_limits<int32_t>::max()) {
    return InvalidArgument(
        "xfeed (infeed or outfeed) buffer length %d is outside the valid "
        "size range",
        length);
  }
  int32_t length_32 = static_cast<int32_t>(length);

  int32_t shape_length;
  TF_ASSIGN_OR_RETURN(
      llvm::Value * shape_ptr,
      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));

  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());

  const char* acquire_func_name =
      kind == XfeedKind::kInfeed
          ? runtime::kAcquireInfeedBufferForDequeueSymbolName
          : runtime::kAcquireOutfeedBufferForPopulationSymbolName;

  // Implementation note: this call informs the runtime that it wants a
  // buffer of size exactly 'length_32', and the runtime is responsible for
  // check-failing the process if there is a mismatch, versus passing us
  // back a buffer that we might overrun.
  llvm::Value* acquired_pointer =
      EmitCallToFunc(acquire_func_name,
                     {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                      shape_ptr, b_.getInt32(shape_length)},
                     i8_ptr_type);
  if (kind == XfeedKind::kInfeed) {
    // Copy to the program buffer address from the acquired buffer.
    MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
           acquired_pointer,
           /*SrcAlign=*/llvm::Align(1), length_32);
  } else {
    // Outfeed -- copy from the in-program address to the acquired buffer.
    MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
           program_buffer_address,
           /*SrcAlign=*/llvm::Align(1), length_32);
    if (emit_code_for_msan_) {
      // Mark the outfed data as initialized for msan. The buffer gets read by
      // the host code, which might be msan-instrumented.
      // TODO(b/66051036): Run the msan instrumentation pass instead.
      const llvm::DataLayout& dl = module_->getDataLayout();
      llvm::Type* intptr_type = b_.getIntPtrTy(dl);
      EmitCallToFunc(
          "__msan_unpoison",
          {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
          b_.getVoidTy());
    }
  }

  const char* release_func_name =
      kind == XfeedKind::kInfeed
          ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
          : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
  EmitCallToFunc(release_func_name,
                 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                  acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
                 b_.getVoidTy());

  return OkStatus();
}

Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed produces no useful result, but it does return a token[] that can be
  // threaded through to other side-effecting operations to ensure ordering.  In
  // the IR emitter we treat this token as a normal u8[] and thus need to insert
  // an entry for it in emitted_value_.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));

  HloInstruction* operand = outfeed->operands()[0];
  const Shape& operand_shape = operand->shape();

  llvm::Value* value = GetEmittedValueFor(operand);
  if (!operand_shape.IsTuple()) {
    return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
  }

  TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));

  for (int i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(operand_shape, i);
    llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
        tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
        value, IrShapeType(operand_shape), &b_);
    TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
                                         tuple_element_shape, tuple_element));
  }

  return OkStatus();
}

Status IrEmitter::HandleSort(HloInstruction* hlo) {
  const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
  Shape keys_shape = sort->keys()->shape();
  PrimitiveType keys_type = keys_shape.element_type();
  if (!primitive_util::IsArrayType(keys_type)) {
    return Unimplemented("Element type %s not supported in the Sort op on CPU.",
                         PrimitiveType_Name(keys_type));
  }
  std::vector<llvm::Value*> destination_addresses(sort->operand_count());
  for (int64_t i = 0; i < sort->operand_count(); ++i) {
    ShapeIndex shape_index =
        sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
    const HloInstruction* operand = sort->operand(i);
    // We assume that the layout of all involved operands and outputs is the
    // same.
    TF_RET_CHECK(
        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
        keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));

    // The sort is implemented in place, so we first copy the operand
    // buffer to the output buffer if they are not the same.
    auto destination_buffer = GetAllocationSlice(*sort, shape_index);
    destination_addresses[i] =
        EmitBufferPointer(destination_buffer, operand->shape());
    auto source_address = GetAllocationSlice(*operand);
    if (destination_buffer != source_address) {
      int64_t primitive_type_size =
          ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
      auto source_buffer = GetEmittedValueFor(operand);
      int64_t size = ByteSizeOf(operand->shape());
      MemCpy(destination_addresses[i],
             /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
             /*SrcAlign=*/llvm::Align(primitive_type_size), size);
    }
  }

  // Normalize the shape and the dimension to sort.
  Shape normalized_keys_shape =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
  auto logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(keys_shape.layout());
  TF_RET_CHECK(sort->sort_dimension() < logical_to_physical.size());
  int64_t physical_dimension_to_sort =
      logical_to_physical[sort->sort_dimension()];

  int64_t sort_dimension_elements =
      normalized_keys_shape.dimensions(physical_dimension_to_sort);
  int64_t higher_dimensions = 1;
  for (int64_t i = 0; i < physical_dimension_to_sort; ++i) {
    higher_dimensions *= normalized_keys_shape.dimensions(i);
  }
  int64_t lower_dimensions = 1;
  for (int64_t i = normalized_keys_shape.rank() - 1;
       i > physical_dimension_to_sort; --i) {
    lower_dimensions *= normalized_keys_shape.dimensions(i);
  }
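  // The runtime sort routine views each operand as a
  // [higher_dimensions, sort_dimension_elements, lower_dimensions] array and
  // sorts independently along the middle dimension, so all physical dimensions
  // before (after) the sort dimension are collapsed into a single leading
  // (trailing) dimension here.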

  CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
  llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
      &b_);
  llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
      &b_);
  for (int64_t i = 0; i < sort->operand_count(); ++i) {
    llvm::Value* value_as_i8ptr =
        PointerCast(destination_addresses[i], b_.getInt8PtrTy());
    llvm::Value* slot_in_values_alloca =
        ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
    Store(value_as_i8ptr, slot_in_values_alloca);
    llvm::Value* slot_in_sizes_alloca =
        ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
    llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
        sort->operand(i)->shape().element_type()));
    Store(size, slot_in_sizes_alloca);
  }

  auto less_than_function =
      FindOrDie(emitted_functions_,
                ComputationToEmit{sort->to_apply(), allow_reassociation_});
  EmitCallToFunc(
      runtime::kKeyValueSortSymbolName,
      {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
       b_.getInt64(lower_dimensions), values,
       b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
       GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
       less_than_function},
      b_.getVoidTy());

  if (sort->values_count() > 0) {
    llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
  }
  return OkStatus();
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
  llvm::SmallVector<llvm::Value*> base_ptrs;
  for (auto operand : tuple->operands()) {
    base_ptrs.push_back(GetEmittedValueFor(operand));
  }
  llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
  return OkStatus();
}

Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
  // Pseudo code for reduce window:
  //
  //   for (coordinates O in the output)
  //     value = init_value;
  //     for (coordinates W in the window)
  //       for each index i:
  //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
  //       if I within bounds of input:
  //         value = function(value, input(I));
  //     output(O) = value;
  //
  // This is completely un-optimized and just here to have something
  // that works.
  bool saved_allow_reassociation = allow_reassociation_;
  allow_reassociation_ = true;
  Status status = DefaultAction(reduce_window);
  allow_reassociation_ = saved_allow_reassociation;
  return status;
}

Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  CHECK_EQ(select_and_scatter->operand_count(), 3);
  const auto operand = select_and_scatter->operand(0);
  const auto source = select_and_scatter->operand(1);
  const auto init_value = select_and_scatter->operand(2);
  const Window& window = select_and_scatter->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
  const int64_t rank = operand->shape().rank();
  CHECK_EQ(rank, source->shape().rank());
  CHECK_EQ(rank, window.dimensions_size());

  // TODO(b/31410564): Implement dilation for select-and-scatter.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for SelectAndScatter is not implemented on CPU. ");
  }

  // Pseudo code for select-and-scatter:
  //
  // initialized_flag is initially off for every window, and is turned on after
  // the first iteration is completed and the first operand value is selected.
  //
  // output(*) = init_value
  // for (coordinates S in the source) {
  //   initialized_flag = false
  //   for (coordinates W in the window) {
  //     I = S * stride + W - pad_low
  //     if I within bounds of operand:
  //       if !initialized_flag or select(selected_value, operand(I)) == false:
  //         selected_value = operand(I)
  //         selected_index = I
  //         initialized_flag = true
  //   }
  //   output(selected_index) = scatter(output(selected_index), source(S))
  // }
  //

  // Initialize the output array with the given init_value.
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
      [this, init_value](const llvm_ir::IrArray::Index& target_index) {
        llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
        return Load(IrShapeType(init_value->shape()), init_value_addr);
      }));

  // Create a loop to iterate over the source array to scatter to the output.
  llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
  const llvm_ir::IrArray::Index source_index =
      source_loops.AddLoopsForShape(source->shape(), "source");
  SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Allocate space to keep the currently selected value, its index, and
  // the boolean initialized_flag, which is initially set to false.
  llvm::AllocaInst* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
      "selected_value_address", &b_,
      MinimumAlignmentForPrimitiveType(operand_element_type));
  llvm::AllocaInst* selected_index_address =
      llvm_ir::EmitAllocaAtFunctionEntryWithCount(
          b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
  llvm::AllocaInst* initialized_flag_address =
      llvm_ir::EmitAllocaAtFunctionEntry(b_.getInt1Ty(),
                                         "initialized_flag_address", &b_);
  Store(b_.getInt1(false), initialized_flag_address);

  // Create the inner loop to iterate over the window.
  llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
  llvm::SmallVector<int64_t> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
  SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Compute the operand index to visit and evaluate the condition whether the
  // operand index is within the bounds. The unsigned comparison includes
  // checking whether the operand index >= 0.
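  // (A negative index wraps around to a very large unsigned value, so a single
  // unsigned less-than test against the dimension bound covers both 0 <= I and
  // I < bound.)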
  llvm::SmallVector<llvm::Value*> operand_multi_index(source_index.size());
  llvm::Value* in_bounds_condition = b_.getTrue();
  for (int64_t i = 0; i < rank; ++i) {
    llvm::Value* strided_index =
        NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
    operand_multi_index[i] =
        NSWSub(NSWAdd(strided_index, window_index[i]),
               b_.getInt64(window.dimensions(i).padding_low()));
    llvm::Value* index_condition =
        ICmpULT(operand_multi_index[i],
                b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
    in_bounds_condition = And(in_bounds_condition, index_condition);
  }
  CHECK(in_bounds_condition != nullptr);

  // Only need to do something if the operand index is within the bounds. First
  // check if the initialized_flag is set.
  llvm_ir::LlvmIfData if_in_bounds =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
  SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
  llvm_ir::LlvmIfData if_initialized =
      llvm_ir::EmitIfThenElse(Load(initialized_flag_address->getAllocatedType(),
                                   initialized_flag_address),
                              "initialized", &b_);

  // If the initialized_flag is false, initialize the selected value and index
  // with the currently visiting operand.
  SetToFirstInsertPoint(if_initialized.false_block, &b_);
  const auto save_operand_index =
      [&](const llvm_ir::IrArray::Index& operand_index) {
        for (int64_t i = 0; i < rank; ++i) {
          llvm::Value* selected_index_address_slot =
              InBoundsGEP(selected_index_address->getAllocatedType(),
                          selected_index_address, {b_.getInt32(i)});
          Store(operand_index[i], selected_index_address_slot);
        }
      };
  llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
  llvm_ir::IrArray::Index operand_index(
      operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
  llvm::Value* operand_data =
      operand_array.EmitReadArrayElement(operand_index, &b_);
  Store(operand_data, selected_value_address);
  save_operand_index(operand_index);
  Store(b_.getInt1(true), initialized_flag_address);

  // If the initialized_flag is true, call the `select` function to potentially
  // update the selected value and index with the currently visiting operand.
  SetToFirstInsertPoint(if_initialized.true_block, &b_);
  llvm::Value* operand_address =
      operand_array.EmitArrayElementAddress(operand_index, &b_);
  llvm::Value* operand_element =
      Load(operand_array.GetElementLlvmType(), operand_address);
  llvm::Value* result = EmitScalarReturningThreadLocalCall(
      *select_and_scatter->select(),
      {Load(selected_value_address->getAllocatedType(), selected_value_address),
       operand_element},
      "select_function");

  // If the 'select' function returns false, update the selected value and the
  // index to the currently visiting operand.
  llvm::Value* cond = ICmpNE(
      result,
      llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
      "boolean_predicate");
  llvm_ir::LlvmIfData if_select_lhs =
      llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
  SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
  Store(Load(operand_array.GetElementLlvmType(), operand_address),
        selected_value_address);
  save_operand_index(operand_index);

  // After iterating over the window elements, scatter the source element to
  // the selected index of the output. The value we store at the output
  // location is computed by calling the `scatter` function with the source
  // value and the current output value.
  SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
  llvm::SmallVector<llvm::Value*> selected_multi_index;
  for (int64_t i = 0; i < rank; ++i) {
    const std::vector<llvm::Value*> gep_index = {b_.getInt32(i)};
    llvm::Value* selected_index_address_slot =
        InBoundsGEP(selected_index_address->getAllocatedType(),
                    selected_index_address, gep_index);
    llvm::Type* type = llvm::GetElementPtrInst::getIndexedType(
        selected_index_address->getAllocatedType(), gep_index);
    selected_multi_index.push_back(Load(type, selected_index_address_slot));
  }
  llvm_ir::IrArray source_array(GetIrArrayFor(source));
  llvm::Value* source_value =
      source_array.EmitReadArrayElement(source_index, &b_);
  llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
  llvm_ir::IrArray::Index selected_index(
      selected_multi_index, output_array.GetShape(), source_index.GetType());
  llvm::Value* output_value =
      output_array.EmitReadArrayElement(selected_index, &b_);
  llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
      *select_and_scatter->scatter(), {output_value, source_value},
      "scatter_function");
  output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);

  SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
  return OkStatus();
}

Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs = dot->operand(0);
  auto rhs = dot->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*dot, /*operands=*/{lhs, rhs},
      /*supported_types=*/
      {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();

  if (dnums.lhs_contracting_dimensions_size() != 1) {
    // This is disallowed by ShapeInference today.
    return Unimplemented(
        "Dot with multiple contracting dimensions not implemented.");
  }

  llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
  llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
  llvm_ir::IrArray target_array = GetIrArrayFor(dot);

  VLOG(2) << "HandleDot: ";
  VLOG(2) << "  lhs operand: "
          << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
  VLOG(2) << "  rhs operand: "
          << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
  VLOG(2) << "  target: "
          << llvm_ir::DumpToString(*target_array.GetBasePointer());

  // The dot operation is complicated, so we delegate to a helper class.
  return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
                          /*addend_array=*/nullptr,
                          GetExecutableRunOptionsArgument(), &b_, mlir_context_,
                          hlo_module_config_, target_machine_features_);
}

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
      /*supported_types=*/
      {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));

  // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
  // different data layouts.
  if (PotentiallyImplementedAsEigenConvolution(*convolution,
                                               target_machine_features_)) {
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();
    const Shape& convolution_shape = convolution->shape();
    // The input, kernel and output agree with respect to layout.
    if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
      // We lower 1D convolutions into calls to the same Eigen function as 2D
      // convolutions, except that we pretend that the 1D convolution is really
      // a 2D convolution with the missing dimension set to 1.  We also adjust
      // the padding and dilation parameters as needed.
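      // For example, an f32[b,w,c] 1D convolution is treated as an
      // f32[b,1,w,c] 2D convolution, with stride 1, zero padding, and
      // dilation 1 in the prepended dimension.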
      bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
      llvm::Value* lhs_address = GetEmittedValueFor(lhs);
      llvm::Value* rhs_address = GetEmittedValueFor(rhs);
      TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));

      const ConvolutionDimensionNumbers& dnums =
          convolution->convolution_dimension_numbers();

      absl::InlinedVector<int64_t, 2> input_dims;
      absl::InlinedVector<int64_t, 2> kernel_dims;
      absl::InlinedVector<int64_t, 2> output_dims;
      if (one_dim_convolution) {
        input_dims.push_back(1);
        kernel_dims.push_back(1);
        output_dims.push_back(1);
      }

      // Input tensor.
      const Shape& input_shape = convolution->operand(0)->shape();
      int64_t input_batch =
          input_shape.dimensions(dnums.input_batch_dimension());
      for (int d : dnums.input_spatial_dimensions()) {
        input_dims.push_back(input_shape.dimensions(d));
      }
      int64_t input_channels =
          input_shape.dimensions(dnums.input_feature_dimension());

      // Kernel tensor.
      const Shape& kernel_shape = convolution->operand(1)->shape();
      for (int d : dnums.kernel_spatial_dimensions()) {
        kernel_dims.push_back(kernel_shape.dimensions(d));
      }
      int64_t kernel_channels =
          kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
      int64_t kernel_filters =
          kernel_shape.dimensions(dnums.kernel_output_feature_dimension());

      // Output tensor. convolution_shape is already bound to
      // convolution->shape() in the enclosing scope, so we reuse it here
      // rather than redeclaring a shadowing copy.
      for (int d : dnums.output_spatial_dimensions()) {
        output_dims.push_back(convolution_shape.dimensions(d));
      }

      // Extract the window stride for the convolution.
      const Window& window = convolution->window();
      absl::InlinedVector<int64_t, 2> strides;
      absl::InlinedVector<std::pair<int64_t, int64_t>, 2> padding;
      absl::InlinedVector<int64_t, 2> base_dilation;
      absl::InlinedVector<int64_t, 2> window_dilation;
      if (one_dim_convolution) {
        strides.push_back(1);
        padding.push_back({0, 0});
        base_dilation.push_back(1);
        window_dilation.push_back(1);
      }
      for (const auto& d : window.dimensions()) {
        strides.push_back(d.stride());
        padding.push_back({d.padding_low(), d.padding_high()});
        base_dilation.push_back(d.base_dilation());
        window_dilation.push_back(d.window_dilation());
      }

      PrimitiveType primitive_type = lhs->shape().element_type();
      llvm::Type* ir_ptr_type = primitive_type == F16
                                    ? b_.getHalfTy()->getPointerTo()
                                    : b_.getFloatTy()->getPointerTo();
      bool multi_threaded =
          hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
      bool use_mkl_dnn =
          hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn() &&
          convolution->feature_group_count() == 1;
      bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();

      auto valid_num_dims = [](absl::Span<const int64_t> xs) {
        return xs.size() >= 2 && xs.size() <= 3;
      };
      TF_RET_CHECK(valid_num_dims(input_dims)) << input_dims.size();
      TF_RET_CHECK(valid_num_dims(kernel_dims));
      TF_RET_CHECK(valid_num_dims(output_dims));
      TF_RET_CHECK(valid_num_dims(strides));
      TF_RET_CHECK(padding.size() >= 2 && padding.size() <= 3);
      TF_RET_CHECK(valid_num_dims(base_dilation));
      TF_RET_CHECK(valid_num_dims(window_dilation));

      // TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to
      // a potential race condition when setting omp_num_threads.
      const char* fn_name;
      if (input_dims.size() == 2) {
        fn_name =
            primitive_type == F16
                ? (multi_threaded
                       ? runtime::kEigenConv2DF16SymbolName
                       : runtime::kEigenSingleThreadedConv2DF16SymbolName)
                : (multi_threaded
                       ? (use_mkl_dnn
                              ? runtime::kMKLConv2DF32SymbolName
                              : (use_acl ? runtime::kACLConv2DF32SymbolName
                                         : runtime::kEigenConv2DF32SymbolName))
                       : runtime::kEigenSingleThreadedConv2DF32SymbolName);
      } else if (input_dims.size() == 3) {
        fn_name =
            primitive_type == F16
                ? (multi_threaded
                       ? runtime::kEigenConv3DF16SymbolName
                       : runtime::kEigenSingleThreadedConv3DF16SymbolName)
                : (multi_threaded
                       ? runtime::kEigenConv3DF32SymbolName
                       : runtime::kEigenSingleThreadedConv3DF32SymbolName);
      } else {
        LOG(FATAL) << "Invalid number of dimensions " << input_dims.size();
      }
      if (!multi_threaded && use_mkl_dnn) {
        LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
                        "convolution.";
      }
      std::vector<llvm::Value*> args = {
          GetExecutableRunOptionsArgument(),
          BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
          BitCast(lhs_address, ir_ptr_type),
          BitCast(rhs_address, ir_ptr_type),
          b_.getInt64(input_batch),
      };
      for (int64_t d : input_dims) {
        args.push_back(b_.getInt64(d));
      }
      args.push_back(b_.getInt64(input_channels));
      for (int64_t d : kernel_dims) {
        args.push_back(b_.getInt64(d));
      }
      args.push_back(b_.getInt64(kernel_channels));
      args.push_back(b_.getInt64(kernel_filters));
      for (int64_t d : output_dims) {
        args.push_back(b_.getInt64(d));
      }
      for (int64_t d : strides) {
        args.push_back(b_.getInt64(d));
      }
      for (const auto& p : padding) {
        args.push_back(b_.getInt64(p.first));
        args.push_back(b_.getInt64(p.second));
      }
      for (int64_t d : base_dilation) {
        args.push_back(b_.getInt64(d));
      }
      for (int64_t d : window_dilation) {
        args.push_back(b_.getInt64(d));
      }
      args.push_back(b_.getInt64(convolution->feature_group_count()));

      VLOG(1) << "Ir emitter emitted Convolution to runtime:" << fn_name;
      EmitCallToFunc(fn_name, args, b_.getVoidTy(), /*does_not_throw=*/true,
                     /*only_accesses_arg_memory=*/true);

      return OkStatus();
    }
  }
  // This is a completely un-optimized version of convolution just to
  // have an early version that works. E.g. the input index and
  // padding calculation is not hoisted out of the inner loop.
  //
  // See the description of convolution in the XLA documentation for the pseudo
  // code for convolution.
  return DefaultAction(convolution);
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  auto operand = fft->operand(0);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*fft, /*operands=*/{operand},
      /*supported_types=*/{F32, F64, C64, C128}));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
  VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
  VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());

  llvm::Value* operand_address = GetEmittedValueFor(operand);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));

  const std::vector<int64_t>& fft_length = fft->fft_length();
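  // Collapse all leading (non-FFT) dimensions into a single batch dimension;
  // the runtime performs one FFT over the trailing fft_length dimensions per
  // batch element.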
  int64_t input_batch = 1;
  for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
    input_batch *= fft->shape().dimensions(i);
  }

  // Args have been computed, make the call.
  llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
  bool multi_threaded_eigen =
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  const char* fn_name = multi_threaded_eigen
                            ? runtime::kEigenFftSymbolName
                            : runtime::kEigenSingleThreadedFftSymbolName;
  const int fft_rank = fft_length.size();
  EmitCallToFunc(
      fn_name,
      {GetExecutableRunOptionsArgument(),
       BitCast(GetEmittedValueFor(fft), int8_ptr_type),
       BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
       b_.getInt32(operand->shape().element_type() == F64 ||
                   operand->shape().element_type() == C128),
       b_.getInt32(fft_rank), b_.getInt64(input_batch),
       b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
       b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
       b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
      b_.getVoidTy(), /*does_not_throw=*/true,
      /*only_accesses_arg_memory=*/false,
      /*only_accesses_inaccessible_mem_or_arg_mem=*/true);

  return OkStatus();
}

HandleAllReduceSingleReplica(HloInstruction * crs)1138 Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
1139   // When there is a single replica, a cross replica sum is the identity
1140   // function, and the buffer assignment expects a copy.
1141   //
1142   // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1143   // in algebraic-simplifier, but currently on some platforms
1144   // HloModuleConfig::num_replicas changes between when the module is compiled
1145   // and when it's run.
1146   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1147 
1148   // CRS with one operand and one replica is simply the identity function.
1149   if (crs->operand_count() == 1) {
1150     return EmitMemcpy(*crs->operand(0), *crs);
1151   }
1152 
1153   // CRS with multiple operands and one replica produces a (one-deep) tuple.
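  // As a sketch, a hypothetical crs = all-reduce(a, b) with one replica lowers
  // to tuple(copy(a), copy(b)): each operand is memcpy'd into its output slice
  // below and a tuple of the output pointers is emitted at the end.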
1154   std::vector<llvm::Value*> operand_ptrs;
1155   for (int64_t i = 0; i < crs->operand_count(); ++i) {
1156     llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1157     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1158                         assignment_.GetUniqueSlice(crs, {i}));
1159 
1160     const Shape& operand_shape = crs->operand(i)->shape();
1161     CHECK(operand_shape.IsArray())
1162         << "Operands to all-reduce must be arrays: " << crs->ToString();
1163     operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1164 
1165     // TODO(b/63762267): Be more aggressive about specifying alignment.
1166     MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
1167            /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
1168   }
1169   llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1170   return OkStatus();
1171 }
1172 
1173 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
1174   CHECK_GE(crs->operand_count(), 1);
1175   PrimitiveType datatype = crs->operand(0)->shape().element_type();
1176   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1177 
1178   bool is_datatype_supported = [&] {
1179     // TODO(cheshire): Fix duplication wrt. cpu_runtime
1180     switch (datatype) {
1181       case PRED:
1182       case S8:
1183       case U8:
1184       case S32:
1185       case U32:
1186       case S64:
1187       case U64:
1188       case F16:
1189       case F32:
1190       case F64:
1191       case C64:
1192       case C128:
1193         return true;
1194       default:
1195         return false;
1196     }
1197   }();
1198 
1199   if (!is_datatype_supported) {
1200     return Unimplemented("AllReduce for datatype '%s' is not supported",
1201                          primitive_util::LowercasePrimitiveTypeName(datatype));
1202   }
1203 
1204   if (!MatchReductionComputation(crs->to_apply()).has_value()) {
1205     return Unimplemented("AllReduce for computation '%s' is not supported",
1206                          crs->to_apply()->ToString());
1207   }
1208 
1209   std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
1210   int32_t replica_groups_size = replica_groups.size();
1211   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
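  // E.g. replica groups [[0,1],[2,3]] serialize to a string along the lines of
  // "{0,1},{2,3}" (an assumption about ReplicaGroupsToString's exact format),
  // which the runtime parses back on the other side of the call.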
1212 
1213   bool is_tuple = crs->operand_count() > 1;
1214   std::vector<llvm::Value*> input_buffer_ptrs;
1215   std::vector<llvm::Value*> output_buffer_ptrs;
1216   if (is_tuple) {
1217     CHECK(crs->shape().IsTuple());
1218 
1219     for (int64_t i = 0; i < crs->operand_count(); i++) {
1220       const HloInstruction* op = crs->operand(i);
1221       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1222                           assignment_.GetUniqueSlice(crs, {i}));
1223       const Shape& operand_shape = crs->operand(i)->shape();
1224       CHECK(operand_shape.IsArray())
1225           << "Operands to all-reduce must be arrays: " << crs->ToString();
1226       output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1227       input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1228     }
1229   } else {
1230     Shape shape = crs->operand(0)->shape();
1231     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1232                         assignment_.GetUniqueSlice(crs->operand(0), {}));
1233     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1234                         assignment_.GetUniqueSlice(crs, {}));
1235     input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
1236     output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
1237   }
1238 
1239   llvm::Value* input_buffers =
1240       EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1241   llvm::Value* output_buffers =
1242       EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1243 
1244   int32_t shape_length;
1245   TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
1246                       llvm_ir::EncodeSelfDescribingShapeConstant(
1247                           crs->shape(), &shape_length, &b_));
1248 
1249   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1250   bool use_global_device_ids =
1251       Cast<HloAllReduceInstruction>(crs)->use_global_device_ids();
1252   EmitCallToFunc(
1253       runtime::kAllReduceSymbolName,
1254       {/*run_options=*/GetExecutableRunOptionsArgument(),
1255        /*replica_groups=*/replica_groups_v,
1256        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1257 
1258        /*channel_id_present=*/
1259        b_.getInt32(static_cast<int32_t>(crs->channel_id().has_value())),
1260        /*use_global_device_ids=*/
1261        b_.getInt32(static_cast<int32_t>(use_global_device_ids)),
1262        /*op_id=*/
1263        b_.getInt64(crs->channel_id().has_value()
1264                        ? *crs->channel_id()
1265                        : crs->GetModule()->unique_id()),
1266        /*reduction_kind=*/
1267        b_.getInt32(
1268            static_cast<int32_t>(*MatchReductionComputation(crs->to_apply()))),
1269        /*shape_ptr=*/shape_ptr,
1270        /*shape_length=*/b_.getInt32(shape_length),
1271        /*num_buffers=*/b_.getInt32(crs->operand_count()),
1272        /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1273        /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1274       b_.getVoidTy());
1275 
1276   return OkStatus();
1277 }
1278 
1279 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1280   if (hlo_module_config_.replica_count() == 1 &&
1281       hlo_module_config_.num_partitions() == 1) {
1282     return HandleAllReduceSingleReplica(crs);
1283   }
1284   return HandleAllReduceMultipleReplica(crs);
1285 }
1286 
1287 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1288   auto* instr = Cast<HloAllToAllInstruction>(instruction);
1289   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1290   CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1291       << "Only tuple AllToAll is supported";
1292 
1293   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1294   std::string replica_groups =
1295       ReplicaGroupsToString(instruction->replica_groups());
1296   int32_t replica_groups_size = replica_groups.size();
1297   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1298 
1299   int64_t buffer_size = -1;
1300   std::vector<llvm::Value*> input_buffer_ptrs;
1301   std::vector<llvm::Value*> output_buffer_ptrs;
1302 
1303   for (int64_t i = 0; i < instruction->operand_count(); i++) {
1304     const HloInstruction* op = instruction->operand(i);
1305     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1306                         assignment_.GetUniqueSlice(instruction, {i}));
1307     const Shape& operand_shape = instruction->operand(i)->shape();
1308     CHECK(operand_shape.IsArray())
1309         << "Operands to all-to-all must be arrays: " << instruction->ToString();
1310     output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1311     input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1312     CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1313     buffer_size = out_slice.size();
1314   }
1315 
1316   llvm::Value* input_buffers =
1317       EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1318   llvm::Value* output_buffers =
1319       EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1320 
1321   EmitCallToFunc(
1322       runtime::kAllToAllSymbolName,
1323       {/*run_options=*/GetExecutableRunOptionsArgument(),
1324        /*channel_id_present=*/
1325        b_.getInt32(static_cast<int32_t>(instruction->channel_id().has_value())),
1326        /*op_id=*/
1327        b_.getInt64(instruction->channel_id().has_value()
1328                        ? *instruction->channel_id()
1329                        : instruction->GetModule()->unique_id()),
1330        /*replica_groups=*/replica_groups_v,
1331        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1332        /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1333        /*buffer_size=*/b_.getInt64(buffer_size),
1334        /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1335        /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1336       b_.getVoidTy());
1337 
1338   llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1339   return OkStatus();
1340 }
1341 
1342 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1343   auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1344   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
1345   std::string source_target_pairs = absl::StrJoin(
1346       instr->source_target_pairs(), ",", absl::PairFormatter("="));
1347   llvm::Value* source_target_pairs_v =
1348       b_.CreateGlobalStringPtr(source_target_pairs);
1349 
1350   Shape shape = crs->operand(0)->shape();
1351 
1352   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1353                       assignment_.GetUniqueSlice(crs->operand(0), {}));
1354   llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1355 
1356   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1357                       assignment_.GetUniqueSlice(crs, {}));
1358   llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1359 
1360   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1361   EmitCallToFunc(
1362       runtime::kCollectivePermuteSymbolName,
1363       {/*run_options=*/GetExecutableRunOptionsArgument(),
1364        /*channel_id_present=*/
1365        b_.getInt32(static_cast<int32_t>(crs->channel_id().has_value())),
1366        /*op_id=*/
1367        b_.getInt64(crs->channel_id().has_value()
1368                        ? *crs->channel_id()
1369                        : crs->GetModule()->unique_id()),
1370        /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1371        /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1372        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1373        /*source_target_pairs=*/source_target_pairs_v,
1374        /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1375       b_.getVoidTy());
1376 
1377   return OkStatus();
1378 }
1379 
1380 Status IrEmitter::HandlePartitionId(HloInstruction* hlo) {
1381   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1382   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1383                       assignment_.GetUniqueSlice(hlo, {}));
1384   llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1385   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1386   EmitCallToFunc(
1387       runtime::kPartitionIdSymbolName,
1388       {/*run_options=*/GetExecutableRunOptionsArgument(),
1389        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1390       b_.getVoidTy());
1391   return OkStatus();
1392 }
1393 
1394 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1395   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1396   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1397                       assignment_.GetUniqueSlice(hlo, {}));
1398   llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1399   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1400   EmitCallToFunc(
1401       runtime::kReplicaIdSymbolName,
1402       {/*run_options=*/GetExecutableRunOptionsArgument(),
1403        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1404       b_.getVoidTy());
1405   return OkStatus();
1406 }
1407 
1408 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1409   VLOG(2) << "HandleParameter: " << parameter->ToString();
1410   return EmitTargetAddressForOp(parameter);
1411 }
1412 
1413 // Returns true if the relative order of the unreduced dimensions stays the same
1414 // through the reduce operation.
1415 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1416   DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1417 
1418   // Maps dimensions that were not reduced from their dimension numbers in the
1419   // source shape to their dimension numbers in the destination shape.
1420   //
1421   // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1422   // [0->0, 3->1].
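  //
  // As a sketch: with operand layout {3,2,1,0} and result layout {1,0}, the
  // minor-to-major walk below visits operand dims 3,2,1,0; the unreduced dims
  // 3 and 0 map to result dims 1 and 0, matching the result's minor-to-major
  // order, so this reduction preserves layout and the check returns true.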
1423   absl::flat_hash_map<int64_t, int64_t> unreduced_dim_map;
1424 
1425   absl::flat_hash_set<int64_t> reduced_dims(reduce.dimensions().begin(),
1426                                             reduce.dimensions().end());
1427 
1428   const Shape& operand_shape = reduce.operand(0)->shape();
1429   const Shape& result_shape = reduce.shape();
1430 
1431   int64_t delta = 0;
1432   for (int64_t i = 0; i < operand_shape.dimensions_size(); i++) {
1433     if (reduced_dims.contains(i)) {
1434       delta++;
1435     } else {
1436       InsertOrDie(&unreduced_dim_map, i, i - delta);
1437     }
1438   }
1439 
1440   // Iterate dimensions minor to major and check that the corresponding
1441   // dimensions in the source and target shapes are equivalent.
1442   int64_t result_dim_idx = 0;
1443   for (int64_t operand_dim_idx = 0;
1444        operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1445     int64_t operand_dim =
1446         operand_shape.layout().minor_to_major(operand_dim_idx);
1447     if (!reduced_dims.contains(operand_dim)) {
1448       if (FindOrDie(unreduced_dim_map, operand_dim) !=
1449           result_shape.layout().minor_to_major(result_dim_idx++)) {
1450         return false;
1451       }
1452     }
1453   }
1454 
1455   CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1456 
1457   return true;
1458 }
1459 
1460 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1461     HloComputation* function, std::string* failure_reason) const {
1462   CHECK_EQ(function->num_parameters(), 2);
1463 
1464   auto root_instruction = function->root_instruction();
1465   CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1466 
1467   if (root_instruction->operand_count() != 2) {
1468     *failure_reason = "root instruction is not a binary operation";
1469     return nullptr;
1470   }
1471 
1472   const Shape& root_shape = root_instruction->shape();
1473   if (ShapeUtil::ElementIsComplex(root_shape)) {
1474     // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
1475     // Complex multiply would be more challenging. We could perhaps use a
1476     // strided load to get all reals in a vector, all imaginary parts in a
1477     // vector, or use CreateShuffleVector on a bitcast to float x [2N].
1478     *failure_reason = "complex values not supported";
1479     return nullptr;
1480   }
1481   bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1482   bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1483   bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1484 
1485   auto lhs = root_instruction->operand(0);
1486   auto rhs = root_instruction->operand(1);
1487 
1488   auto param_0 = function->parameter_instruction(0);
1489   auto param_1 = function->parameter_instruction(1);
1490   if (!(lhs == param_0 && rhs == param_1) &&
1491       !(rhs == param_0 && lhs == param_1)) {
1492     *failure_reason =
1493         "root instruction is not a binary operation on the incoming arguments";
1494     return nullptr;
1495   }
1496 
1497   CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1498 
1499   // This is visually similar to ElementalIrEmitter, though conceptually we're
1500   // doing something different here.  ElementalIrEmitter emits scalar operations
1501   // while these emit scalar or vector operations depending on the type of the
1502   // operands. See CreateShardedVectorType for the actual types in use here.
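  //
  // For instance (a sketch, assuming an F32 kAdd reduction): the returned
  // generator simply forwards to b->CreateFAdd(lhs, rhs), where lhs/rhs may be
  // scalars or vector shards such as <8 x float>.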
1503   switch (root_instruction->opcode()) {
1504     default:
1505       *failure_reason = "did not recognize root instruction opcode";
1506       return nullptr;
1507 
1508     case HloOpcode::kAdd:
1509       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1510                                 llvm::Value* rhs) {
1511         return root_is_integral ? b->CreateAdd(lhs, rhs)
1512                                 : b->CreateFAdd(lhs, rhs);
1513       };
1514 
1515     case HloOpcode::kMultiply:
1516       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1517                                 llvm::Value* rhs) {
1518         return root_is_integral ? b->CreateMul(lhs, rhs)
1519                                 : b->CreateFMul(lhs, rhs);
1520       };
1521 
1522     case HloOpcode::kAnd:
1523       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1524         return b->CreateAnd(lhs, rhs);
1525       };
1526 
1527     case HloOpcode::kOr:
1528       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1529         return b->CreateOr(lhs, rhs);
1530       };
1531 
1532     case HloOpcode::kXor:
1533       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1534         return b->CreateXor(lhs, rhs);
1535       };
1536 
1537     case HloOpcode::kMaximum:
1538       return [root_is_floating_point, root_is_signed, this](
1539                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1540                  llvm::Value* rhs) -> llvm::Value* {
1541         if (root_is_floating_point) {
1542           return llvm_ir::EmitFloatMax(
1543               lhs, rhs, b,
1544               hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1545         }
1546 
1547         return b->CreateSelect(
1548             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1549                                          : llvm::ICmpInst::ICMP_UGE,
1550                           lhs, rhs),
1551             lhs, rhs);
1552       };
1553 
1554     case HloOpcode::kMinimum:
1555       return [root_is_floating_point, root_is_signed, this](
1556                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1557                  llvm::Value* rhs) -> llvm::Value* {
1558         if (root_is_floating_point) {
1559           return llvm_ir::EmitFloatMin(
1560               lhs, rhs, b,
1561               hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1562         }
1563 
1564         return b->CreateSelect(
1565             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1566                                          : llvm::ICmpInst::ICMP_ULE,
1567                           lhs, rhs),
1568             lhs, rhs);
1569       };
1570   }
1571 }
1572 
1573 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1574     PrimitiveType element_type, unsigned element_count) {
1575   int vector_register_size_in_elements =
1576       target_machine_features_.vector_register_byte_size(
1577           *compute_function_->function()) /
1578       ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1579 
1580   ShardedVectorType sharded_vector_type;
1581   llvm::Type* element_ir_type =
1582       llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1583 
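  // As a worked example (assuming 32-byte vector registers and F32, i.e.
  // vector_register_size_in_elements == 8): element_count = 13 = 8 + 4 + 1
  // decomposes into the shards {<8 x float>, <4 x float>, float}.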
1584   for (int i = 0, e = 1 + Log2Ceiling(element_count); i < e; i++) {
1585     // For every power of two present in element_count, we generate one or more
1586     // vector or scalar types.
1587     const unsigned current_size_fragment = 1u << i;
1588     if (!(element_count & current_size_fragment)) {
1589       // Power of two not present in element_count.
1590       continue;
1591     }
1592 
1593     if (current_size_fragment == 1) {
1594       // Single element, use a scalar type.
1595       sharded_vector_type.push_back(element_ir_type);
1596       continue;
1597     }
1598 
1599     // Lower "current_size_fragment" number of elements using (as few as
1600     // possible) vector registers.
1601 
1602     if (current_size_fragment >= vector_register_size_in_elements) {
1603       auto vector_type = llvm::VectorType::get(
1604           element_ir_type, vector_register_size_in_elements, false);
1605       sharded_vector_type.insert(
1606           sharded_vector_type.end(),
1607           current_size_fragment / vector_register_size_in_elements,
1608           vector_type);
1609 
1610       // Both current_size_fragment and vector_register_size_in_elements are
1611       // powers of two.
1612       CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1613       continue;
1614     }
1615 
1616     // For now we assume that vector_register_size_in_elements and lower powers
1617     // of two are all legal vector sizes (or at least can be lowered easily by
1618     // LLVM).
1619     sharded_vector_type.push_back(
1620         llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1621   }
1622   return sharded_vector_type;
1623 }
1624 
1625 StatusOr<IrEmitter::ShardedVector>
1626 IrEmitter::EmitInnerLoopForVectorizedReduction(
1627     const ReductionGenerator& reduction_generator,
1628     const llvm_ir::IrArray::Index& output_index,
1629     const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1630     HloInstruction* arg, absl::Span<const int64_t> dimensions,
1631     llvm::Align element_alignment) {
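  // Rough sketch of what this emits, in pseudo code:
  //
  //   acc[0..num_shards) = splat(init)   // scalar init for a scalar shard
  //   for (r in the reduced dimensions of arg):
  //     for each shard i:
  //       acc[i] = reduction_generator(acc[i], load of shard i at input[r,...])
  //   return loads of acc[0..num_shards)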
1632   ShardedVector accumulator;
1633   accumulator.reserve(accumulator_type.size());
1634   for (auto accumulator_shard_type : accumulator_type) {
1635     accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1636         accumulator_shard_type, "accumulator", &b_, 0));
1637   }
1638 
1639   llvm::Value* init_value_ssa =
1640       Load(IrShapeType(init_value->shape()), GetEmittedValueFor(init_value));
1641 
1642   for (llvm::Value* accumulator_shard : accumulator) {
1643     llvm::Value* initial_value;
1644     auto shard_type =
1645         llvm::cast<llvm::AllocaInst>(accumulator_shard)->getAllocatedType();
1646     if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1647       initial_value =
1648           VectorSplat(vector_type->getElementCount(), init_value_ssa);
1649     } else {
1650       initial_value = init_value_ssa;
1651     }
1652 
1653     AlignedStore(initial_value, accumulator_shard, element_alignment);
1654   }
1655 
1656   llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1657                                            &b_);
1658   std::vector<llvm::Value*> input_multi_index =
1659       reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1660                                                        "reduction_dim");
1661 
1662   SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1663 
1664   llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1665   llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1666 
1667   for (auto& i : input_multi_index) {
1668     if (i == nullptr) {
1669       i = *it++;
1670     }
1671   }
1672   CHECK(output_index.end() == it);
1673   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1674                                       b_.getInt64Ty());
1675 
1676   llvm::Value* input_address = BitCast(
1677       arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1678 
1679   for (int i = 0; i < accumulator.size(); i++) {
1680     auto input_address_typed =
1681         BitCast(input_address, accumulator[i]->getType());
1682     auto alloca = llvm::cast<llvm::AllocaInst>(accumulator[i]);
1683     auto current_accumulator_value = AlignedLoad(
1684         alloca->getAllocatedType(), accumulator[i], element_alignment);
1685     auto addend = AlignedLoad(alloca->getAllocatedType(), input_address_typed,
1686                               element_alignment);
1687     arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1688 
1689     auto reduced_result =
1690         reduction_generator(&b_, current_accumulator_value, addend);
1691     AlignedStore(reduced_result, accumulator[i], element_alignment);
1692 
1693     if (i != (accumulator.size() - 1)) {
1694       input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1695                                            input_address_typed, 1);
1696     }
1697   }
1698 
1699   SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1700 
1701   ShardedVector result_ssa;
1702   result_ssa.reserve(accumulator.size());
1703   for (auto accumulator_shard : accumulator) {
1704     auto alloca = llvm::cast<llvm::AllocaInst>(accumulator_shard);
1705     result_ssa.push_back(AlignedLoad(alloca->getAllocatedType(),
1706                                      accumulator_shard, element_alignment));
1707   }
1708   return result_ssa;
1709 }
1710 
1711 void IrEmitter::EmitShardedVectorStore(
1712     llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1713     llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
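  // For example, storing shards {<4 x float>, float} writes the 4-wide vector
  // first, then advances the address by one whole <4 x float> (16 bytes)
  // before storing the trailing scalar.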
1714   for (int i = 0; i < value_to_store.size(); i++) {
1715     auto store_address_typed =
1716         BitCast(store_address,
1717                 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1718 
1719     auto store_instruction =
1720         AlignedStore(value_to_store[i], store_address_typed, alignment);
1721     containing_array.AnnotateLoadStoreInstructionWithMetadata(
1722         store_instruction);
1723 
1724     if (i != (value_to_store.size() - 1)) {
1725       store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1726                                            store_address_typed, 1);
1727     }
1728   }
1729 }
1730 
1731 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1732     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1733     absl::Span<const int64_t> dimensions, HloComputation* function,
1734     std::string* failure_reason) {
1735   if (!reduce->shape().IsArray()) {
1736     *failure_reason = "vectorization of variadic reduce not implemented";
1737     return false;
1738   }
1739 
1740   if (!ReductionPreservesLayout(*reduce)) {
1741     return false;
1742   }
1743 
1744   ReductionGenerator reduction_generator =
1745       MatchReductionGenerator(function, failure_reason);
1746   if (!reduction_generator) {
1747     return false;
1748   }
1749 
1750   int vector_register_size_in_elements =
1751       target_machine_features_.vector_register_byte_size(
1752           *compute_function_->function()) /
1753       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1754   if (vector_register_size_in_elements == 0) {
1755     // Either we don't know the vector register width for the target or the
1756     // vector register is smaller than the size of the primitive type.
1757     return false;
1758   }
1759 
1760   int vectorization_factor_in_bytes =
1761       target_machine_features_.vectorization_factor_in_bytes();
1762 
1763   // We try to process vectorization_factor elements at the same time.
1764   const int vectorization_factor =
1765       vectorization_factor_in_bytes /
1766       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
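  // E.g. with a 16-byte vectorization factor and an F32 reduce, we attempt to
  // produce 16 / 4 = 4 output elements per iteration of the inner loop.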
1767 
1768   bool is_reduction_over_minor_dimension = absl::c_linear_search(
1769       dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1770 
1771   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1772       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1773       MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1774 
1775   if (is_reduction_over_minor_dimension) {
1776     // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1777     *failure_reason = "reduction over minor dimension not implemented";
1778     return false;
1779   }
1780 
1781   CHECK(!reduce->shape().IsTuple());
1782   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1783 
1784   // We know we're not reducing over the most minor dimension, which means we
1785   // can lower the reduction loop as:
1786   //
1787   //  1. We're reducing over dimensions R0, R1.
1788   //  2. D0 is the most minor dimension.
1789   //  3. VS is the vectorization stride (we want to reduce this many elements at
1790   //     once)
1791   //
1792   //  for (d1 in D1) {
1793   //    for (d0 in D0 with stride VS) {
1794   //      vector_acc = init
1795   //      for (r1 in R1) {
1796   //        for (r0 in R0) {
1797   //          vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1798   //        }
1799   //      }
1800   //      output[d1, d0] = vector_acc
1801   //    }
1802   //  }
1803 
1804   llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1805   std::vector<llvm::Value*> array_multi_index(
1806       reduce->shape().dimensions_size());
1807   for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1808        --i) {
1809     int64_t dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1810     int64_t start_index = 0;
1811     int64_t end_index = reduce->shape().dimensions(dimension);
1812     std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1813         start_index, end_index, absl::StrFormat("dim.%d", dimension));
1814     array_multi_index[dimension] = loop->GetIndVarValue();
1815   }
1816 
1817   int64_t innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1818   int64_t innermost_dimension_size =
1819       reduce->shape().dimensions(innermost_dimension);
1820 
1821   if (llvm::BasicBlock* innermost_body_bb =
1822           loop_nest.GetInnerLoopBodyBasicBlock()) {
1823     SetToFirstInsertPoint(innermost_body_bb, &b_);
1824   }
1825 
1826   auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1827 
1828   if (innermost_dimension_size >= vectorization_factor) {
1829     int64_t start_index = 0;
1830     int64_t end_index = (innermost_dimension_size / vectorization_factor) *
1831                         vectorization_factor;
1832     std::unique_ptr<llvm_ir::ForLoop> loop =
1833         loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1834                           absl::StrFormat("dim.%d", innermost_dimension));
1835     array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1836 
1837     SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1838 
1839     ShardedVectorType vector_type = CreateShardedVectorType(
1840         reduce->shape().element_type(), vectorization_factor);
1841     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1842                                         b_.getInt64Ty());
1843     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1844                         EmitInnerLoopForVectorizedReduction(
1845                             reduction_generator, array_index, vector_type,
1846                             init_value, arg, dimensions, element_alignment));
1847 
1848     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1849     llvm::Value* output_address =
1850         target_array.EmitArrayElementAddress(array_index, &b_);
1851     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1852                            target_array);
1853 
1854     if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1855       CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1856       b_.SetInsertPoint(exit_terminator);
1857     } else {
1858       CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1859       b_.SetInsertPoint(loop->GetExitBasicBlock());
1860     }
1861   }
1862 
1863   // Since we step through the inner dimension with a stride greater than 1,
1864   // we may need to peel off an "epilogue" iteration to handle the remaining
1865   // elements in the following case:
1866   if (innermost_dimension_size % vectorization_factor) {
1867     // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1868     array_multi_index[innermost_dimension] =
1869         b_.getInt64(innermost_dimension_size -
1870                     (innermost_dimension_size % vectorization_factor));
1871 
1872     ShardedVectorType vector_type = CreateShardedVectorType(
1873         reduce->shape().element_type(),
1874         innermost_dimension_size % vectorization_factor);
1875     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1876                                         b_.getInt64Ty());
1877     llvm::IRBuilderBase::FastMathFlagGuard guard(b_);
1878     llvm::FastMathFlags flags = b_.getFastMathFlags();
1879     flags.setAllowReassoc(true);
1880     b_.setFastMathFlags(flags);
1881     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1882                         EmitInnerLoopForVectorizedReduction(
1883                             reduction_generator, array_index, vector_type,
1884                             init_value, arg, dimensions, element_alignment));
1885 
1886     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1887     llvm::Value* output_address =
1888         target_array.EmitArrayElementAddress(array_index, &b_);
1889     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1890                            target_array);
1891   }
1892 
1893   if (outermost_loop_exit_block) {
1894     b_.SetInsertPoint(outermost_loop_exit_block);
1895   }
1896 
1897   return true;
1898 }
1899 
1900 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1901   auto arg = reduce->mutable_operand(0);
1902   auto init_value = reduce->mutable_operand(1);
1903   absl::Span<const int64_t> dimensions(reduce->dimensions());
1904   HloComputation* function = reduce->to_apply();
1905   bool saved_allow_reassociation = allow_reassociation_;
1906   allow_reassociation_ = true;
1907   auto cleanup = absl::MakeCleanup([saved_allow_reassociation, this]() {
1908     allow_reassociation_ = saved_allow_reassociation;
1909   });
1910   if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1911     std::string vectorization_failure_reason;
1912     TF_ASSIGN_OR_RETURN(
1913         bool vectorization_successful,
1914         EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1915                              &vectorization_failure_reason));
1916     if (vectorization_successful) {
1917       VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1918               << "\n";
1919       return OkStatus();
1920     } else {
1921       VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1922               << vectorization_failure_reason;
1923     }
1924   }
1925 
1926   return DefaultAction(reduce);
1927 }
1928 
1929 Status IrEmitter::HandleSend(HloInstruction* send) {
1930   // TODO(b/33942983): Support Send/Recv on CPU.
1931   return Unimplemented("Send is not implemented on CPU.");
1932 }
1933 
1934 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1935   // TODO(b/33942983): Support Send/Recv on CPU.
1936   return Unimplemented("Send-done is not implemented on CPU.");
1937 }
1938 
1939 Status IrEmitter::HandleScatter(HloInstruction*) {
1940   return Unimplemented("Scatter is not implemented on CPUs.");
1941 }
1942 
1943 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1944   VLOG(2) << "HandleSlice: " << slice->ToString();
1945   auto operand = slice->operand(0);
1946   // The code below emits a sequential loop nest. For the parallel backend, use
1947   // ParallelLoopEmitter which respects dynamic loop bounds.
1948   if (ShouldEmitParallelLoopFor(*slice)) {
1949     return DefaultAction(slice);
1950   }
1951 
1952   // The code below assumes the layouts are equal.
1953   if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1954     return DefaultAction(slice);
1955   }
1956 
1957   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1958 
1959   if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1960     return OkStatus();
1961   }
1962 
1963   const Layout& layout = operand->shape().layout();
1964   const int64_t num_dims = operand->shape().dimensions_size();
1965 
1966   // The slice lowering finds maximal contiguous blocks of memory that can be
1967   // copied from the source to the target. This is done by looking at the
1968   // source/target layout in minor-to-major order and doing the following:
1969   //
1970   // * Find an initial segment of dimensions along which the slice uses the
1971   //   whole dimension. These are the "inner" dimensions and can be folded into
1972   //   the memcpy.
1973   //
1974   // * Of the remaining dimensions decide which ones require loops.
1975   //
1976   // * Implement the memcpy within the innermost loop.
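  //
  // Worked example (a sketch): operand f32[10,16]{1,0} sliced to f32[5,16]
  // with starts = {2, 0} and unit strides. Dimension 1 is copied whole, so
  // inner_dims = {1}; memcpy_dim = 0 is contiguous, so a single memcpy of
  // 5 * 16 = 80 elements suffices and no outer loops are needed.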
1977 
1978   absl::flat_hash_set<int64_t> inner_dims;
1979   for (int64_t dim : LayoutUtil::MinorToMajor(layout)) {
1980     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1981       break;
1982     }
1983     inner_dims.insert(dim);
1984   }
1985 
1986   const bool is_trivial_copy = (inner_dims.size() == num_dims);
1987   if (is_trivial_copy) {
1988     if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1989       return DefaultAction(slice);
1990     } else {
1991       return EmitMemcpy(*slice, *operand);
1992     }
1993   }
1994 
1995   // The memcpy will copy elements that are logically this shape (allowed to be
1996   // scalar).
1997   const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1998       [&inner_dims](int64_t dim) { return inner_dims.contains(dim); },
1999       operand->shape());
2000 
2001   const int64_t primitive_elements_per_logical_element =
2002       ShapeUtil::ElementsIn(logical_element_shape);
2003 
2004   // memcpy_dim is the innermost (in terms of layout) dimension for which the
2005   // slice does *not* just copy all the elements along the dimension.
2006   const int64_t memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
2007 
2008   const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
2009   // The number of logical elements that can be copied in a single call
2010   // to memcpy. We can only copy 1 element at a time if there is a non-trivial
2011   // stride.
2012   const int64_t memcpy_logical_elements =
2013       memcpy_is_contiguous
2014           ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
2015           : 1;
2016 
2017   // Determine the dimensions that get lowered as loops.
2018   llvm::SmallVector<int64_t> outer_dims;
2019   for (int64_t i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
2020     outer_dims.push_back(LayoutUtil::Major(layout, i));
2021   }
2022 
2023   // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
2024   // needs to be wrapped in a loop as well.
2025   if (!memcpy_is_contiguous) {
2026     outer_dims.push_back(memcpy_dim);
2027   }
2028 
2029   llvm_ir::IrArray target_array = GetIrArrayFor(slice);
2030 
2031   const int64_t num_outer_loops = outer_dims.size();
2032   llvm_ir::ForLoopNest loops(IrName(slice), &b_);
2033   std::vector<llvm::Value*> target_multi_index =
2034       loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
2035 
2036   // Only the indices for the outer dimensions have been initialized in
2037   // target_multi_index. The rest of the indices should get initialized to 0,
2037   // since for the rest of the dimensions the copy writes to the full dimension.
2038   // for the rest of the dimensions the copy writes to the full dimension.
2039   std::replace(target_multi_index.begin(), target_multi_index.end(),
2040                static_cast<llvm::Value*>(nullptr),
2041                static_cast<llvm::Value*>(b_.getInt64(0)));
2042   llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
2043                                        b_.getInt64Ty());
2044 
2045   if (num_outer_loops > 0) {
2046     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2047   }
2048 
2049   llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2050   const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
2051       /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
2052       /*strides=*/slice->slice_strides(), /*builder=*/&b_);
2053 
2054   llvm::Value* memcpy_dest =
2055       target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
2056   llvm::Value* memcpy_source =
2057       source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
2058 
2059   const int64_t memcpy_elements =
2060       primitive_elements_per_logical_element * memcpy_logical_elements;
2061 
2062   EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
2063                        slice->shape().element_type(), target_array,
2064                        source_array);
2065 
2066   if (VLOG_IS_ON(2)) {
2067     const int64_t memcpy_bytes =
2068         ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
2069     VLOG(2) << "  emitted copy of " << memcpy_bytes << " bytes inside "
2070             << num_outer_loops << " loops";
2071   }
2072 
2073   if (num_outer_loops > 0) {
2074     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2075   }
2076 
2077   return OkStatus();
2078 }
2079 
2080 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
2081   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
2082     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
2083     return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
2084   }
2085   return DefaultAction(dynamic_slice);
2086 }
2087 
2088 Status IrEmitter::HandleDynamicUpdateSlice(
2089     HloInstruction* dynamic_update_slice) {
2090   auto update = dynamic_update_slice->operand(1);
2091   if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
2092     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2093     return EmitMemcpy(*update, *dynamic_update_slice);
2094   } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
2095                                                    assignment_)) {
2096     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2097     auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
2098     return llvm_ir::EmitDynamicUpdateSliceInPlace(
2099         operands, GetIrArrayFor(dynamic_update_slice),
2100         IrName(dynamic_update_slice, "in_place"), &b_);
2101   }
2102   return DefaultAction(dynamic_update_slice);
2103 }
2104 
2105 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2106   // TODO(b/33942983): Support Send/Recv on CPU.
2107   return Unimplemented("Recv is not implemented on CPU.");
2108 }
2109 
2110 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2111   // TODO(b/33942983): Support Send/Recv on CPU.
2112   return Unimplemented("Recv-done is not implemented on CPU.");
2113 }
2114 
2115 Status IrEmitter::HandlePad(HloInstruction* pad) {
2116   // The CPU backend does not properly handle negative padding, but this is OK
2117   // because negative padding should be removed by the algebraic simplifier.
2118   for (auto& padding_dimension : pad->padding_config().dimensions()) {
2119     if (padding_dimension.edge_padding_low() < 0 ||
2120         padding_dimension.edge_padding_high() < 0) {
2121       return InternalErrorStrCat(
2122           "Encountered negative padding in IrEmitter on CPU. "
2123           "This should have been eliminated at the HLO level. ",
2124           pad->ToString());
2125     }
2126   }
2127 
2128   // First, fill in the padding value to all output elements.
2129   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2130       pad, "initialize",
2131       [this, pad](const llvm_ir::IrArray::Index& target_index) {
2132         const HloInstruction* padding_value = pad->operand(1);
2133         llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2134         return Load(IrShapeType(padding_value->shape()), padding_value_addr);
2135       }));
2136 
2137   // Create a loop to iterate over the operand elements and update the output
2138   // locations where the operand elements should be stored.
2139   llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2140   const HloInstruction* operand = pad->operand(0);
2141   const llvm_ir::IrArray::Index operand_index =
2142       loops.AddLoopsForShape(operand->shape(), "operand");
2143 
2144   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2145 
2146   // Load an element from the operand.
2147   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2148   llvm::Value* operand_data =
2149       operand_array.EmitReadArrayElement(operand_index, &b_);
2150 
2151   // Compute the output index the operand element should be assigned to.
2152   // output_index := edge_padding_low + operand_index * (interior_padding + 1)
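  // E.g. with edge_padding_low = 3 and interior_padding = 1, operand index 2
  // lands at output index 3 + 2 * 2 = 7.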
2153   const PaddingConfig& padding_config = pad->padding_config();
2154   std::vector<llvm::Value*> output_multi_index;
2155   for (size_t i = 0; i < operand_index.size(); ++i) {
2156     llvm::Value* offset =
2157         Mul(operand_index[i],
2158             b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2159     llvm::Value* index = Add(
2160         offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2161     output_multi_index.push_back(index);
2162   }
2163 
2164   // Store the operand element to the computed output location.
2165   llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2166   llvm_ir::IrArray::Index output_index(
2167       output_multi_index, output_array.GetShape(), operand_index.GetType());
2168   output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2169 
2170   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2171   return OkStatus();
2172 }
2173 
2174 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2175   auto* root = fusion->fused_expression_root();
2176   if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2177     VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2178     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2179     FusedIrEmitter fused_emitter(elemental_emitter);
2180     BindFusionArguments(fusion, &fused_emitter);
2181 
2182     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2183     // Delegate to common implementation of fused in-place dynamic-update-slice.
2184     return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2185         fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2186   } else if (fusion->IsLoopFusion()) {
2187     VLOG(3) << "HandleFusion kLoop";
2188     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2189     FusedIrEmitter fused_emitter(elemental_emitter);
2190     BindFusionArguments(fusion, &fused_emitter);
2191     TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2192                                             *fusion->fused_expression_root()));
2193     return EmitTargetElementLoop(fusion, generator);
2194   } else if (fusion->IsOutputFusion()) {
2195     VLOG(3) << "HandleFusion kOutput";
2196     int64_t dot_op_index =
2197         root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2198     const HloInstruction* dot = root->operand(dot_op_index);
2199     CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2200         << dot->ToString() << "  "
2201         << fusion->fused_instructions_computation()->ToString();
2202 
2203     int64_t dot_lhs_param_number = dot->operand(0)->parameter_number();
2204     int64_t dot_rhs_param_number = dot->operand(1)->parameter_number();
2205     int64_t addend_param_number =
2206         root->operand(1 - dot_op_index)->parameter_number();
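    // E.g. for a fused computation whose root is add(dot(p0, p1), p2),
    // dot_op_index is 0 and the addend parameter is p2.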
2207 
2208     Shape target_shape = fusion->shape();
2209     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2210     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2211 
2212     llvm_ir::IrArray lhs_array(
2213         GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2214     llvm_ir::IrArray rhs_array(
2215         GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2216     llvm_ir::IrArray addend_array(
2217         GetIrArrayFor(fusion->operand(addend_param_number)));
2218 
2219     TF_RETURN_IF_ERROR(EmitDotOperation(
2220         *dot, target_array, lhs_array, rhs_array, &addend_array,
2221         GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2222         hlo_module_config_, target_machine_features_));
2223     return OkStatus();
2224   } else {
2225     return Unimplemented("Fusion kind not implemented on CPU");
2226   }
2227 }
2228 
2229 Status IrEmitter::HandleCall(HloInstruction* call) {
2230   HloComputation* computation = call->to_apply();
2231   llvm::Function* call_ir_function = FindOrDie(
2232       emitted_functions_, ComputationToEmit{computation, allow_reassociation_});
2233 
2234   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2235 
2236   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2237     // Having a nonempty set of 'outer_dimension_partitions' means that this
2238     // computation has been specially selected to be parallelized (one where the
2239     // root instruction is trivially parallelizable, like elementwise addition
2240     // of two tensors). The LLVM function generated for this computation accepts
2241     // an additional set of loop bounds, allowing the caller to control the
2242     // subset of the output that is generated by each call.
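    //
    // For instance (a sketch), a root of shape f32[16] with
    // outer_dimension_partitions = {2} is emitted so that the runtime can
    // invoke the function twice, once per index range [0, 8) and [8, 16).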
2243 
2244     std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2245         {}, &b_, computation->name(),
2246         /*return_value_buffer=*/emitted_value_[call],
2247         /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2248         /*buffer_table_arg=*/GetBufferTableArgument(),
2249         /*status_arg=*/GetStatusArgument(),
2250         /*profile_counters_arg=*/GetProfileCountersArgument());
2251 
2252     // The parallel fork/join runtime will call the generated function once for
2253     // each partition in parallel, using an appropriate set of loop bounds for
2254     // each call such that it only generates one partition of the output.
2255     HloInstruction* root = computation->root_instruction();
2256     TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2257         call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2258         call_ir_function, computation->name()));
2259 
2260     if (ComputationTransitivelyContainsCustomCall(computation)) {
2261       EmitEarlyReturnIfErrorStatus();
2262     }
2263   } else {
2264     EmitGlobalCall(*computation, computation->name());
2265   }
2266 
2267   return OkStatus();
2268 }
2269 
2270 Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2271   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2272   std::vector<llvm::Value*> dynamic_dims;
2273   int32_t raw_data_size =
2274       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2275   llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2276   llvm::Value* raw_buffer =
2277       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
2278   for (int64_t i = 1; i < hlo->operand_count(); ++i) {
2279     const int64_t dim_index = i - 1;
2280     llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2281     llvm::LoadInst* dyn_dim_size = Load(IrShapeType(hlo->operand(i)->shape()),
2282                                         source_buffer, "dyn_dim_size");
2283 
2284     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2285         b_.getInt8Ty(), raw_buffer,
2286         raw_data_size + dim_index * sizeof(int32_t));
2287     b_.CreateStore(dyn_dim_size,
2288                    b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2289     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2290                                             /*isSigned=*/true,
2291                                             "i64_dyn_dim_size"));
2292   }
2293 
2294   llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2295   // Pseudo code for sliceToDynamic:
2296   //
2297   //   for (index i in dynamic_dim)
2298   //     dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2299   //     dest[dest_index] = source[i]
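  //
  //   E.g. with static dims [2,4] and dynamic dims {2,3}, source index (1,2)
  //   linearizes to 1 * 3 + 2 = 5, which delinearizes to dest index (1,1).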
2300   auto loop_body_emitter =
2301       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2302     llvm::Value* source_element =
2303         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2304     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2305     // Delinearize the index based on the static shape.
2306     llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2307                                        &b_);
2308     data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2309     return OkStatus();
2310   };
2311   return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2312                               dynamic_dims, &b_)
2313       .EmitLoop(IrName(hlo));
2314 }
2315 
2316 Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2317   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2318 
2319   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2320                       assignment_.GetUniqueSlice(hlo, {0}));
2321   std::vector<llvm::Value*> dynamic_dims;
2322   std::vector<llvm::Value*> tuple_operand_ptrs;
2323   const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2324   const Shape& input_shape = hlo->operand(0)->shape();
2325   llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2326   llvm::Type* data_type = IrShapeType(data_shape);
2327   llvm_ir::IrArray data_array(data_address, data_type, data_shape);
2328   llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2329   llvm::Value* raw_buffer =
2330       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2331   int64_t raw_data_size =
2332       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2333 
2334   // Put a placeholder for the data array's pointer
2335   tuple_operand_ptrs.push_back(data_array.GetBasePointer());
2336   // PadToStatic has a dynamic tensor as input and a variadic number of
2337   // outputs: (static_tensor, dynamic_dim_0, dynamic_dim_1, ...).
2338   // Dynamic dimension sizes start at output index 1.
2339   for (int i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2340     // Read from the metadata section of the dynamic input (operand 0).
2341     const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2342     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2343     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2344                         assignment_.GetUniqueSlice(hlo, {i}));
2345     llvm::Value* dest_dim_size_address =
2346         EmitBufferPointer(dim_size_slice, data_shape);
2347     const int64_t dim_index = i - 1;
2348     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2349         b_.getInt8Ty(), raw_buffer,
2350         raw_data_size + dim_index * sizeof(int32_t));
2351     llvm::Value* dyn_dim_size = b_.CreateLoad(
2352         b_.getInt32Ty(),
2353         b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2354         "dyn_dim_size");
2355     b_.CreateStore(dyn_dim_size,
2356                    b_.CreateBitCast(dest_dim_size_address,
2357                                     b_.getInt32Ty()->getPointerTo()));
2358     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2359                                             /*isSigned=*/true,
2360                                             "i64_dyn_dim_size"));
2361     tuple_operand_ptrs.push_back(dest_dim_size_address);
2362   }
2363 
2364   // Pseudo code for padToStatic:
2365   //
2366   //   for (index i in dynamic_dim)
2367   //     source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2368   //     dest[i] = source[source_index]
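  // This is the inverse of the remapping in HandleSliceToDynamic: with the
  // same hypothetical f32[2,4] / dynamic [2,3] example, packed source slot 5
  // (i.e. static index (1,1)) is read back into dest element (1,2).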
2369   auto loop_body_emitter =
2370       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2371     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2372     llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2373     llvm::Value* source_element =
2374         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2375     data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2376     return OkStatus();
2377   };
2378   TF_RETURN_IF_ERROR(
2379       llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2380           .EmitLoop(IrName(hlo)));
2381 
2382   // Emit static tensor and dynamic sizes as one tuple.
2383   llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2384   return OkStatus();
2385 }
2386 
2387 Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2388   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2389   const HloInstruction* input = hlo->operand(0);
2390   const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
2391   const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
2392   TF_RET_CHECK(input->shape().element_type() == F32);
2393   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2394       hlo->shape().tuple_shapes(0).layout()));
2395   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2396       hlo->shape().tuple_shapes(1).layout()));
2397   TF_RET_CHECK(
2398       LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
2399 
2400   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2401                       assignment_.GetUniqueSlice(hlo->operand(0), {}));
2402   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2403                       assignment_.GetUniqueSlice(hlo, {0}));
2404   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2405                       assignment_.GetUniqueSlice(hlo, {1}));
2406   llvm::Value* values_ptr =
2407       EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2408   llvm::Value* out_values_ptr =
2409       EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2410   llvm::Value* out_indices_ptr =
2411       EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
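  // Editor's note: the runtime symbol called below is assumed (not verified
  // against the CPU runtime headers) to have a C signature roughly like
  //
  //   void __xla_cpu_runtime_TopKF32(int64_t batch, int64_t n, int64_t k,
  //                                  const float* values, float* out_values,
  //                                  int32_t* out_indices);
  //
  // matching the argument list marshaled below.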
2412   EmitCallToFunc(
2413       runtime::kTopKF32SymbolName,
2414       {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2415        b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2416        BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2417        BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2418        BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2419       b_.getVoidTy());
2420 
2421   llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2422                      &b_);
2423   return OkStatus();
2424 }
2425 
2426 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2427   if (custom_call->custom_call_target() == "PadToStatic") {
2428     return HandlePadToStatic(custom_call);
2429   }
2430   if (custom_call->custom_call_target() == "SliceToDynamic") {
2431     return HandleSliceToDynamic(custom_call);
2432   }
2433   if (custom_call->custom_call_target() == "TopK") {
2434     return HandleTopK(custom_call);
2435   }
2436 
2437   absl::Span<HloInstruction* const> operands(custom_call->operands());
2438   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2439   llvm::AllocaInst* operands_alloca =
2440       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2441           i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2442   for (size_t i = 0; i < operands.size(); ++i) {
2443     const HloInstruction* operand = operands[i];
2444     llvm::Value* operand_as_i8ptr =
2445         PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2446     llvm::Value* slot_in_operands_alloca = InBoundsGEP(
2447         operands_alloca->getAllocatedType(), operands_alloca, {b_.getInt64(i)});
2448     Store(operand_as_i8ptr, slot_in_operands_alloca);
2449   }
2450   if (emit_code_for_msan_) {
2451     // Mark the alloca as initialized for msan. The buffer gets read by the
2452     // custom callee, which might be msan-instrumented.
2453     // TODO(b/66051036): Run the msan instrumentation pass instead.
2454     const llvm::DataLayout& dl = module_->getDataLayout();
2455     llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2456     EmitCallToFunc(
2457         "__msan_unpoison",
2458         {PointerCast(operands_alloca, i8_ptr_type),
2459          llvm::ConstantInt::get(
2460              intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2461         b_.getVoidTy());
2462   }
2463 
2464   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2465   // Write the tuple table if the output is a tuple.
2466   if (custom_call->shape().IsTuple()) {
2467     std::vector<llvm::Value*> base_ptrs;
2468     for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2469          ++i) {
2470       const Shape& elem_shape =
2471           ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2472       TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2473       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2474                           assignment_.GetUniqueSlice(custom_call, {i}));
2475       llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2476       base_ptrs.push_back(addr);
2477     }
2478     llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2479   }
2480   auto* output_address_arg =
2481       PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2482 
2483   auto typed_custom_call = Cast<HloCustomCallInstruction>(custom_call);
2484   switch (typed_custom_call->api_version()) {
2485     case CustomCallApiVersion::API_VERSION_ORIGINAL:
2486       EmitCallToFunc(custom_call->custom_call_target(),
2487                      {output_address_arg, operands_alloca}, b_.getVoidTy());
2488       break;
2489     case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
2490       EmitCallToFunc(custom_call->custom_call_target(),
2491                      {output_address_arg, operands_alloca, GetStatusArgument()},
2492                      b_.getVoidTy());
2493       EmitEarlyReturnIfErrorStatus();
2494       break;
2495     case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: {
2496       absl::string_view opaque = typed_custom_call->opaque();
2497       EmitCallToFunc(custom_call->custom_call_target(),
2498                      {output_address_arg, operands_alloca,
2499                       b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(opaque)),
2500                       b_.getInt64(opaque.size()), GetStatusArgument()},
2501                      b_.getVoidTy());
2502       EmitEarlyReturnIfErrorStatus();
2503       break;
2504     }
2505     default:
2506       return InternalError(
2507           "Unknown custom-call API version enum value: %d (%s)",
2508           typed_custom_call->api_version(),
2509           CustomCallApiVersion_Name(typed_custom_call->api_version()));
2510   }
2511 
2512   return OkStatus();
2513 }
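// For reference, a custom-call target compatible with API_VERSION_ORIGINAL
// would look roughly like this (editor's sketch; the name is hypothetical):
//
//   extern "C" void MyCustomCall(void* output, const void** operands);
//
// where operands is the cc_operands_alloca array populated above, holding
// one opaque buffer pointer per HLO operand.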
2514 
2515 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2516   // Precondition: Condition computation must return a scalar bool.
2517   HloComputation* condition = xla_while->while_condition();
2518   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2519                condition->root_instruction()->shape().element_type() == PRED)
2520       << "While condition computation must return bool; got: "
2521       << ShapeUtil::HumanString(condition->root_instruction()->shape());
2522   // Check that all while-related buffers share an allocation slice.
2523   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2524       xla_while->shape(),
2525       [this, &xla_while](const Shape& /*subshape*/,
2526                          const ShapeIndex& index) -> Status {
2527         auto check = [this](const HloInstruction* a, const HloInstruction* b,
2528                             const ShapeIndex& index) {
2529           const BufferAllocation::Slice slice_a =
2530               assignment_.GetUniqueSlice(a, index).value();
2531           const BufferAllocation::Slice slice_b =
2532               assignment_.GetUniqueSlice(b, index).value();
2533           if (slice_a != slice_b) {
2534             return InternalError(
2535                 "instruction %s %s does not share slice with "
2536                 "instruction %s %s",
2537                 a->ToString(), slice_a.ToString(), b->ToString(),
2538                 slice_b.ToString());
2539           }
2540           return OkStatus();
2541         };
2542         TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2543         TF_RETURN_IF_ERROR(check(
2544             xla_while, xla_while->while_condition()->parameter_instruction(0),
2545             index));
2546         TF_RETURN_IF_ERROR(
2547             check(xla_while, xla_while->while_body()->parameter_instruction(0),
2548                   index));
2549         TF_RETURN_IF_ERROR(check(
2550             xla_while, xla_while->while_body()->root_instruction(), index));
2551         return OkStatus();
2552       }));
2553 
2554   // Set emitted value to that of 'init' with which it shares an allocation.
2555   const HloInstruction* init = xla_while->operand(0);
2556   emitted_value_[xla_while] = GetEmittedValueFor(init);
2557 
2558   // Generating:
2559   //   while (Condition(while_result)) {
2560   //     // CopyInsertion pass inserts copies which enable 'while_result' to
2561   //     // be passed back in as 'Body' parameter.
2562   //     while_result = Body(while_result);
2563   //   }
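  // Control flow emitted below (editor's sketch of the resulting CFG):
  //
  //   header: call condition; br i1 %pred, label %body, label %exit
  //   body:   call body; br label %header
  //   exit:   (execution resumes here)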
2564 
2565   // Terminates the current block with a branch to a while header.
2566   llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2567       module_->getContext(), IrName(xla_while, "header"),
2568       compute_function_->function());
2569   Br(header_bb);
2570   b_.SetInsertPoint(header_bb);
2571 
2572   // Calls the condition function to determine whether to proceed with the
2573   // body.  It must return a bool, so use the scalar call form.
2574   EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2575   llvm::Value* while_predicate = ICmpNE(
2576       Load(IrShapeType(
2577                xla_while->while_condition()->root_instruction()->shape()),
2578            GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2579       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2580 
2581   // Branches to the body or to the while exit depending on the condition.
2582   llvm::BasicBlock* body_bb =
2583       llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2584                                compute_function_->function());
2585   llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2586       module_->getContext(), IrName(xla_while, "exit"));
2587   CondBr(while_predicate, body_bb, exit_bb);
2588 
2589   // Calls the body function from the body block.
2590   b_.SetInsertPoint(body_bb);
2591 
2592   // Calls the body function.
2593   EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2594 
2595   // Finishes with a branch back to the header.
2596   Br(header_bb);
2597 
2598   // Adds the exit block to the function and sets the insert point there.
2599   compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2600   b_.SetInsertPoint(exit_bb);
2601 
2602   return OkStatus();
2603 }
2604 
2605 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2606     HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2607     std::string* failure_reason) {
2608   if (ShouldEmitParallelLoopFor(*concatenate)) {
2609     *failure_reason =
2610         "cannot generate memcpy-based concat for the parallel CPU backend";
2611     return false;
2612   }
2613 
2614   const Shape& output_shape = concatenate->shape();
2615   for (auto* op : operands) {
2616     if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2617       *failure_reason = "operand has mismatching layouts";
2618       return false;
2619     }
2620   }
2621 
2622   // We split the dimensions into three categories: the dimension over which we
2623   // are concatenating (concat_dim), the dimensions that are minor to it
2624   // (inner_dims) and the dimensions that are major to it (outer_dims).
2625 
2626   int64_t concat_dim = concatenate->dimensions(0);
2627   const Layout& output_layout = output_shape.layout();
2628   auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2629   auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2630 
2631   std::vector<int64_t> inner_dims(output_min2maj.begin(),
2632                                   concat_dim_layout_itr);
2633   std::vector<int64_t> outer_dims(std::next(concat_dim_layout_itr),
2634                                   output_min2maj.end());
2635 
2636   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2637 
2638   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2639   llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2640 
2641   llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2642   std::vector<llvm::Value*> target_multi_index =
2643       loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2644   std::replace(target_multi_index.begin(), target_multi_index.end(),
2645                static_cast<llvm::Value*>(nullptr),
2646                static_cast<llvm::Value*>(b_.getInt64(0)));
2647   llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2648                                        b_.getInt64Ty());
2649 
2650   if (!outer_dims.empty()) {
2651     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2652   }
2653 
2654   PrimitiveType primitive_type = output_shape.element_type();
2655   unsigned primitive_type_size =
2656       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2657 
2658   // Contiguous subregions from each operand to the concatenate contribute to a
2659   // contiguous subregion in the target buffer starting at target_region_begin.
2660   llvm::Value* target_region_begin = BitCast(
2661       target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2662       i8_ptr_type);
2663   int64_t byte_offset_into_target_region = 0;
2664 
2665   int64_t inner_dims_product =
2666       std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2667                       [&](int64_t product, int64_t inner_dim) {
2668                         return product * output_shape.dimensions(inner_dim);
2669                       });
2670 
2671   // For each operand, emit a memcpy from the operand to the target of size
2672   // equal to the product of inner dimensions.
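  // Editor's example (hypothetical shapes): concatenating f32[2,3] and
  // f32[2,5] along dim 1 with a row-major layout gives inner_dims_product
  // = 1, so each outer-loop iteration memcpys 3*4 = 12 and 5*4 = 20 bytes
  // at byte offsets 0 and 12 within that row's target region.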
2673   for (HloInstruction* operand : operands) {
2674     const Shape& input_shape = operand->shape();
2675     llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2676     llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2677                                          b_.getInt64Ty());
2678     llvm::Value* copy_source_address = BitCast(
2679         source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2680         i8_ptr_type);
2681 
2682     llvm::Value* copy_target_address =
2683         GEP(b_.getInt8Ty(), target_region_begin,
2684             b_.getInt64(byte_offset_into_target_region));
2685 
2686     EmitTransferElements(
2687         copy_target_address, copy_source_address,
2688         inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2689         target_array, source_array);
2690 
2691     byte_offset_into_target_region += inner_dims_product *
2692                                       input_shape.dimensions(concat_dim) *
2693                                       primitive_type_size;
2694   }
2695 
2696   if (!outer_dims.empty()) {
2697     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2698   }
2699 
2700   return true;
2701 }
2702 
2703 llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2704                                    absl::Span<llvm::Value* const> arguments) {
2705   llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2706   std::vector<llvm::Value*> call_args;
2707   call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2708   absl::c_copy(arguments, std::back_inserter(call_args));
2709   return b_.CreateCall(
2710       b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2711           "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2712                                             /*isVarArg=*/true)),
2713       call_args);
2714 }
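// Usage sketch (editor's illustration, assuming linear_index is an i32
// llvm::Value* in scope): from emitted code one could write
//
//   EmitPrintf("linear_index=%d\n", {linear_index});
//
// which lowers to a vararg call to the C library's printf with the format
// string interned as a global.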
2715 
2716 llvm::Value* IrEmitter::EmitPrintfToStderr(
2717     absl::string_view fmt, absl::Span<llvm::Value* const> arguments) {
2718   llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2719   std::vector<llvm::Value*> call_args;
2720   call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2721   absl::c_copy(arguments, std::back_inserter(call_args));
2722   return b_.CreateCall(
2723       b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2724           runtime::kPrintfToStderrSymbolName,
2725           llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2726                                   /*isVarArg=*/true)),
2727       call_args);
2728 }
2729 
2730 llvm::Value* IrEmitter::EmitCallToFunc(
2731     std::string func_name, const std::vector<llvm::Value*>& arguments,
2732     llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2733     bool only_accesses_inaccessible_mem_or_arg_mem) {
2734   std::vector<llvm::Type*> types;
2735   types.reserve(arguments.size());
2736   absl::c_transform(arguments, std::back_inserter(types),
2737                     [&](llvm::Value* val) { return val->getType(); });
2738   llvm::FunctionType* func_type =
2739       llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2740   auto func = llvm::dyn_cast<llvm::Function>(
2741       module_->getOrInsertFunction(func_name, func_type).getCallee());
2742   func->setCallingConv(llvm::CallingConv::C);
2743   if (does_not_throw) {
2744     func->setDoesNotThrow();
2745   }
2746   if (only_accesses_arg_memory) {
2747     func->setOnlyAccessesArgMemory();
2748   }
2749   if (only_accesses_inaccessible_mem_or_arg_mem) {
2750     func->setOnlyAccessesInaccessibleMemOrArgMem();
2751   }
2752   return b_.CreateCall(func, arguments);
2753 }
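// Usage sketch (taken from EmitEarlyReturnIfErrorStatus below): runtime
// calls in this file typically go through this helper, e.g.
//
//   EmitCallToFunc(runtime::kStatusIsSuccessSymbolName, {GetStatusArgument()},
//                  b_.getInt1Ty(), /*does_not_throw=*/true,
//                  /*only_accesses_arg_memory=*/true);
//
// The declaration is created (or reused) on demand, and the attribute flags
// let LLVM optimize around the call.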
2754 
2755 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2756                                      int64_t element_count,
2757                                      PrimitiveType primitive_type,
2758                                      const llvm_ir::IrArray& target_array,
2759                                      const llvm_ir::IrArray& source_array) {
2760   unsigned primitive_type_size =
2761       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2762   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2763       primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2764   llvm::Type* primitive_llvm_type =
2765       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2766   llvm::Type* primitive_ptr_type =
2767       llvm::PointerType::getUnqual(primitive_llvm_type);
2768 
2769   if (element_count == 1) {
2770     auto* load_instruction =
2771         AlignedLoad(primitive_llvm_type, BitCast(source, primitive_ptr_type),
2772                     element_alignment);
2773     source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2774     auto* store_instruction =
2775         AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2776                      element_alignment);
2777     target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2778   } else {
2779     auto* memcpy_instruction = b_.CreateMemCpy(
2780         target, /*DstAlign=*/llvm::Align(element_alignment), source,
2781         /*SrcAlign=*/llvm::Align(element_alignment),
2782         element_count * primitive_type_size);
2783 
2784     // The memcpy does the load and the store internally.  The aliasing related
2785     // metadata has to reflect that.
2786     std::map<int, llvm::MDNode*> merged_metadata =
2787         llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2788                                target_array.metadata());
2789     for (const auto& kind_md_pair : merged_metadata) {
2790       memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2791     }
2792   }
2793 }
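// Editor's note: for a single f32 element this emits an aligned load/store
// pair; for, say, 8 f32 elements it emits a single 32-byte memcpy that
// carries the union of both arrays' aliasing metadata.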
2794 
2795 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2796   absl::Span<HloInstruction* const> operands(concatenate->operands());
2797   std::string failure_reason;
2798   TF_ASSIGN_OR_RETURN(
2799       bool successful,
2800       EmitFastConcatenate(concatenate, operands, &failure_reason));
2801   if (successful) {
2802     VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2803     return OkStatus();
2804   }
2805 
2806   VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2807           << ": " << failure_reason;
2808 
2809   return DefaultAction(concatenate);
2810 }
2811 
2812 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2813   auto branch_index = conditional->operand(0);
2814   int num_branches = conditional->branch_count();
2815   TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2816                (branch_index->shape().element_type() == PRED ||
2817                 branch_index->shape().element_type() == S32))
2818       << "Branch index on a conditional must be scalar bool or int32_t; got: "
2819       << ShapeUtil::HumanString(branch_index->shape());
2820 
2821   for (int b = 0; b < num_branches; ++b) {
2822     HloComputation* br_computation = conditional->branch_computation(b);
2823     TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2824                                   br_computation->root_instruction()->shape()))
2825         << "Shape of conditional should be same as the shape of the " << b
2826         << "th branch computation; got: "
2827         << ShapeUtil::HumanString(conditional->shape()) << " and "
2828         << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2829   }
2830 
2831   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2832 
2833   if (branch_index->shape().element_type() == PRED) {
2834     // Emit an if-else to LLVM:
2835     //   if (pred)
2836     //     cond_result = true_computation(true_operand)
2837     //   else
2838     //     cond_result = false_computation(false_operand)
2839     llvm::LoadInst* pred_value = Load(
2840         GetIrArrayFor(branch_index).GetBasePointeeType(),
2841         GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2842     llvm::Value* pred_cond =
2843         ICmpNE(pred_value,
2844                llvm::ConstantInt::get(
2845                    llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2846                "boolean_predicate");
2847     llvm_ir::LlvmIfData if_data =
2848         llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2849 
2850     SetToFirstInsertPoint(if_data.true_block, &b_);
2851     EmitGlobalCall(*conditional->branch_computation(0),
2852                    IrName(conditional, "_true"));
2853 
2854     SetToFirstInsertPoint(if_data.false_block, &b_);
2855     EmitGlobalCall(*conditional->branch_computation(1),
2856                    IrName(conditional, "_false"));
2857 
2858     SetToFirstInsertPoint(if_data.after_block, &b_);
2859     return OkStatus();
2860   }
2861   // We emit a switch statement to LLVM:
2862   // switch (branch_index) {
2863   //   default:
2864   //     result = branch_computations[num_branches-1](operands[num_branches-1]);
2865   //     break;
2866   //   case 0:
2867   //     result = branch_computations[0](operands[0]); break;
2868   //   case 1:
2869   //     result = branch_computations[1](operands[1]); break;
2870   //   ...
2871   //   case [[num_branches-2]]:
2872   //     result = branch_computations[num_branches-2](operands[num_branches-2]);
2873   //     break;
2874   // }
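  // Note: an out-of-range branch_index lands in the switch's default case and
  // therefore selects the last branch computation, which matches HLO
  // semantics for conditionals with an out-of-bounds index.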
2875   llvm::LoadInst* branch_index_value = Load(
2876       GetIrArrayFor(branch_index).GetBasePointeeType(),
2877       GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2878 
2879   auto case_block = b_.GetInsertBlock();
2880   llvm::BasicBlock* after_block;
2881   // Add a terminator to the case block, if necessary.
2882   if (case_block->getTerminator() == nullptr) {
2883     after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2884     b_.SetInsertPoint(case_block);
2885     b_.CreateBr(after_block);
2886   } else {
2887     after_block =
2888         case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2889   }
2890   // Our basic block should now end with an unconditional branch.  Remove it;
2891   // we're going to replace it with a switch-based branch.
2892   case_block->getTerminator()->eraseFromParent();
2893 
2894   // Lower the default branch computation.
2895   auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
2896   b_.SetInsertPoint(default_block);
2897   EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
2898                  IrName(conditional, "_default"));
2899   b_.CreateBr(after_block);
2900 
2901   // Prepare the switch (branch_index) { ... } instruction.
2902   b_.SetInsertPoint(case_block);
2903   llvm::SwitchInst* case_inst =
2904       b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
2905   // Lower each branch's computation.
2906   for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
2907     // Lower the case b: { ... ; break; } computation.
2908     auto branch_block =
2909         llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
2910     b_.SetInsertPoint(branch_block);
2911     EmitGlobalCall(*conditional->branch_computation(b),
2912                    IrName(conditional, absl::StrCat("_branch", b)));
2913     b_.CreateBr(after_block);
2914     case_inst->addCase(b_.getInt32(b), branch_block);
2915   }
2916 
2917   SetToFirstInsertPoint(after_block, &b_);
2918   return OkStatus();
2919 }
2920 
2921 Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
2922   TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
2923   // No code to generate, but we need to emit an address for book-keeping.
2924   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
2925   return OkStatus();
2926 }
2927 
2928 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
2929   // AddDependency just forwards its zero-th operand.
2930   emitted_value_[add_dependency] =
2931       GetEmittedValueFor(add_dependency->operand(0));
2932   return OkStatus();
2933 }
2934 
2935 Status IrEmitter::HandleRng(HloInstruction* rng) {
2936   return Unimplemented("Rng should be expanded for CPU.");
2937 }
2938 
2939 Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
2940   VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
2941   llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
2942       Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
2943       &b_);
2944 
2945   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
2946   llvm::Value* address = GetEmittedValueFor(rng_state);
2947 
2948   // The buffer has an array type while the value is an i128. Cast the
2949   // buffer to i128 type to store the value.
2950   address = BitCast(address, llvm::PointerType::get(
2951                                  old_state->getType()->getScalarType(),
2952                                  address->getType()->getPointerAddressSpace()));
2953   llvm::StoreInst* store = Store(old_state, address);
2954   store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
2955       rng_state->shape().element_type())));
2956 
2957   return OkStatus();
2958 }
2959 
2960 Status IrEmitter::FinishVisit(HloInstruction* root) {
2961   // When this method is called, we should have already emitted an IR value for
2962   // the root (return) op. The IR value holds the address of the buffer holding
2963   // the value. If the root is a constant or parameter, we perform a memcpy from
2964   // this buffer to the retval buffer of the computation. Otherwise, there's
2965   // nothing to do since the result was already written directly into the output
2966   // buffer.
2967   VLOG(2) << "FinishVisit root: " << root->ToString();
2968   if (root->opcode() == HloOpcode::kOutfeed) {
2969     VLOG(2) << "  outfeed with value: "
2970             << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
2971   } else {
2972     VLOG(2) << "  value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
2973   }
2974 
2975   auto record_complete_computation = [&](llvm::Value* prof_counter) {
2976     if (prof_counter) {
2977       profiling_state_.RecordCompleteComputation(&b_, prof_counter);
2978     }
2979   };
2980 
2981   // For the entry computation this increment is cumulative over embedded
2982   // computations, since it includes cycles spent in computations invoked by
2983   // While, Call, etc.
2984   record_complete_computation(GetProfileCounterFor(*root->parent()));
2985   return OkStatus();
2986 }
2987 
2988 template <typename T>
2989 llvm::Value* IrEmitter::GetProfileCounterCommon(
2990     const T& hlo,
2991     const absl::flat_hash_map<const T*, int64_t>& profile_index_map) {
2992   auto it = profile_index_map.find(&hlo);
2993   if (it == profile_index_map.end()) {
2994     return nullptr;
2995   }
2996 
2997   int64_t prof_counter_idx = it->second;
2998   std::string counter_name = IrName("prof_counter", hlo.name());
2999   return GEP(b_.getInt64Ty(), GetProfileCountersArgument(),
3000              b_.getInt64(prof_counter_idx), counter_name);
3001 }
3002 
3003 llvm::Value* IrEmitter::GetProfileCounterFor(
3004     const HloInstruction& instruction) {
3005   return GetProfileCounterCommon<HloInstruction>(instruction,
3006                                                  instruction_to_profile_idx_);
3007 }
3008 
3009 llvm::Value* IrEmitter::GetProfileCounterFor(
3010     const HloComputation& computation) {
3011   return GetProfileCounterCommon<HloComputation>(computation,
3012                                                  computation_to_profile_idx_);
3013 }
3014 
3015 void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
3016                                                      llvm::Value* prof_counter,
3017                                                      llvm::Value* cycle_end,
3018                                                      llvm::Value* cycle_start) {
3019   auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
3020   llvm::LoadInst* old_cycle_count = b->CreateLoad(
3021       llvm::cast<llvm::GetElementPtrInst>(prof_counter)->getSourceElementType(),
3022       prof_counter, "old_cycle_count");
3023   auto* new_cycle_count =
3024       b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
3025   b->CreateStore(new_cycle_count, prof_counter);
3026 }
3027 
3028 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
3029   llvm::Module* module = b->GetInsertBlock()->getModule();
3030   if (!use_rdtscp_) {
3031     llvm::Function* func_llvm_readcyclecounter =
3032         llvm::Intrinsic::getDeclaration(module,
3033                                         llvm::Intrinsic::readcyclecounter);
3034     return b->CreateCall(func_llvm_readcyclecounter);
3035   }
3036   llvm::Function* func_llvm_x86_rdtscp =
3037       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
3038   llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
3039   return b->CreateExtractValue(rdtscp_call, {0});
3040 }
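// Editor's note: the llvm.x86.rdtscp intrinsic returns an aggregate whose
// first element is the 64-bit timestamp counter; the extractvalue above
// discards the TSC_AUX field. llvm.readcyclecounter returns the i64 directly.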
3041 
3042 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
3043                                                  HloInstruction* hlo) {
3044   auto* cycle_start = ReadCycleCounter(b);
3045   cycle_start->setName(IrName(hlo, "cycle_start"));
3046   cycle_starts_[hlo] = cycle_start;
3047   if (first_read_cycle_start_ == nullptr) {
3048     first_read_cycle_start_ = cycle_start;
3049   }
3050 }
3051 
3052 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
3053                                                  HloInstruction* hlo,
3054                                                  llvm::Value* prof_counter) {
3055   auto* cycle_end = ReadCycleCounter(b);
3056   cycle_end->setName(IrName(hlo, "cycle_end"));
3057   auto* cycle_start = cycle_starts_[hlo];
3058   UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
3059   last_read_cycle_end_ = cycle_end;
3060 }
3061 
3062 void IrEmitter::ProfilingState::RecordCompleteComputation(
3063     llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
3064   if (last_read_cycle_end_ && first_read_cycle_start_) {
3065     UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
3066                          first_read_cycle_start_);
3067   }
3068 }
3069 
3070 void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
3071                                                HloInstruction* hlo,
3072                                                llvm::Value* run_options) {
3073   if (!enabled_) {
3074     return;
3075   }
3076 
3077   llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
3078   llvm::Type* void_ptr_type =
3079       int8_ptr_type;  // LLVM does not have a void*; we use an int8_t* instead.
3080   llvm::FunctionType* fn_type =
3081       llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
3082                               /*isVarArg=*/false);
3083 
3084   llvm::Function* function = b->GetInsertBlock()->getParent();
3085   llvm::Module* module = function->getParent();
3086   const char* fn_name = runtime::kTracingStartSymbolName;
3087   llvm::FunctionCallee trace_func =
3088       module->getOrInsertFunction(fn_name, fn_type);
3089   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
3090     fn->setCallingConv(llvm::CallingConv::C);
3091     fn->setDoesNotThrow();
3092     fn->setOnlyAccessesArgMemory();
3093   }
3094   auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
3095   auto* activity_id =
3096       b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
3097                                  b->CreateBitCast(hlo_name, int8_ptr_type)});
3098   activity_id->setName(IrName(hlo, "activity_id"));
3099   activity_ids_[hlo] = activity_id;
3100 }
3101 
3102 void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
3103                                              HloInstruction* hlo,
3104                                              llvm::Value* run_options) {
3105   if (!enabled_) {
3106     return;
3107   }
3108 
3109   llvm::Type* void_ptr_type =
3110       b->getInt8Ty()->getPointerTo();  // LLVM does not have a void*; we use an
3111                                        // int8_t* instead.
3112   llvm::FunctionType* fn_type =
3113       llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
3114                               /*isVarArg=*/false);
3115 
3116   llvm::Function* function = b->GetInsertBlock()->getParent();
3117   llvm::Module* module = function->getParent();
3118   const char* fn_name = runtime::kTracingEndSymbolName;
3119   llvm::FunctionCallee trace_func =
3120       module->getOrInsertFunction(fn_name, fn_type);
3121   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
3122     fn->setCallingConv(llvm::CallingConv::C);
3123     fn->setDoesNotThrow();
3124     fn->setOnlyAccessesArgMemory();
3125   }
3126   auto* activity_id = activity_ids_.at(hlo);
3127   b->CreateCall(trace_func,
3128                 {b->CreateBitCast(run_options, void_ptr_type), activity_id});
3129 }
3130 
3131 namespace {
3132 bool IsHloVeryCheap(const HloInstruction* hlo) {
3133   return hlo->opcode() == HloOpcode::kBitcast ||
3134          hlo->opcode() == HloOpcode::kTuple ||
3135          hlo->opcode() == HloOpcode::kGetTupleElement ||
3136          hlo->opcode() == HloOpcode::kParameter ||
3137          hlo->opcode() == HloOpcode::kConstant ||
3138          hlo->opcode() == HloOpcode::kReplicaId;
3139 }
3140 }  // namespace
3141 
3142 Status IrEmitter::Preprocess(HloInstruction* hlo) {
3143   VLOG(3) << "Visiting: " << hlo->ToString();
3144   // When profiling is enabled, trace the same HLOs that the profiler does.
3145   if (instruction_to_profile_idx_.count(hlo) ||
3146       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3147     tracing_state_.EmitTracingStart(&b_, hlo,
3148                                     GetExecutableRunOptionsArgument());
3149     profiling_state_.RecordCycleStart(&b_, hlo);
3150   }
3151   return OkStatus();
3152 }
3153 
3154 Status IrEmitter::Postprocess(HloInstruction* hlo) {
3155   if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
3156     profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
3157   }
3158   // When profiling is enabled, trace the same HLOs that the profiler does.
3159   if (instruction_to_profile_idx_.count(hlo) ||
3160       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3161     tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
3162   }
3163   return OkStatus();
3164 }
3165 
3166 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
3167   llvm::Value* value_for_op = GetEmittedValueFor(hlo);
3168 
3169   llvm::Type* ir_type = IrShapeType(hlo->shape());
3170   llvm_ir::IrArray array(value_for_op, ir_type, hlo->shape());
3171   AddAliasingInformationToIrArray(*hlo, &array);
3172   return array;
3173 }
3174 
3175 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
3176     const HloInstruction* hlo) {
3177   std::vector<llvm_ir::IrArray> arrays;
3178   std::transform(
3179       hlo->operands().begin(), hlo->operands().end(),
3180       std::back_inserter(arrays),
3181       [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
3182   return arrays;
3183 }
3184 
3185 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
3186   auto it = emitted_value_.find(hlo);
3187   if (it == emitted_value_.end()) {
3188     LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
3189   }
3190   return it->second;
3191 }
3192 
3193 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
3194   return llvm_ir::ShapeToIrType(shape, module_);
3195 }
3196 
3197 llvm::Value* IrEmitter::GetProfileCountersArgument() {
3198   return compute_function_->profile_counters_arg();
3199 }
3200 
3201 llvm::Value* IrEmitter::GetStatusArgument() {
3202   return compute_function_->status_arg();
3203 }
3204 
3205 llvm::Value* IrEmitter::GetBufferTableArgument() {
3206   return compute_function_->buffer_table_arg();
3207 }
3208 
3209 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
3210   return compute_function_->exec_run_options_arg();
3211 }
3212 
3213 llvm::BasicBlock* IrEmitter::GetReturnBlock() {
3214   return compute_function_->return_block();
3215 }
3216 
3217 void IrEmitter::EmitEarlyReturnIfErrorStatus() {
3218   // Use the runtime helper to get the success/failure state as a boolean.
3219   llvm::Value* succeeded =
3220       EmitCallToFunc(runtime::kStatusIsSuccessSymbolName, {GetStatusArgument()},
3221                      b_.getInt1Ty(), /*does_not_throw=*/true,
3222                      /*only_accesses_arg_memory=*/true);
3223   llvm_ir::EmitEarlyReturn(succeeded, &b_, GetReturnBlock());
3224 }
3225 
3226 llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
3227     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3228   const BufferAllocation& allocation = *slice.allocation();
3229   llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
3230     auto param_it =
3231         computation_parameter_allocations_.find(slice.allocation()->index());
3232     if (param_it != computation_parameter_allocations_.end()) {
3233       int64_t param_number = param_it->second;
3234       // We have to access the parameter at offset param_number in the params
3235       // array. The code generated here is equivalent to this C code:
3236       //
3237       //   i8* param_address_untyped = params[param_number];
3238       //   Param* param_address_typed = (Param*)param_address_untyped;
3239       //
3240       // Where Param is the actual element type of the underlying buffer (for
3241       // example, float for an XLA F32 element type).
3242       llvm::Value* params = compute_function_->parameters_arg();
3243       llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(
3244           params, b_.getInt8PtrTy(), param_number, &b_);
3245       llvm::LoadInst* param_address_untyped =
3246           Load(b_.getInt8PtrTy(), param_address_offset);
3247 
3248       if (!target_shape.IsOpaque()) {
3249         AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
3250         AttachDereferenceableMetadataForLoad(param_address_untyped,
3251                                              target_shape);
3252       }
3253       return param_address_untyped;
3254     }
3255 
3256     // Thread-local allocations should only be assigned a single buffer.
3257     const auto& assigned_buffers = allocation.assigned_buffers();
3258     CHECK_EQ(1, assigned_buffers.size());
3259     const Shape& shape = assigned_buffers.begin()->first->shape();
3260 
3261     std::pair<llvm::Function*, BufferAllocation::Slice> key = {
3262         compute_function_->function(), slice};
3263     auto buf_it = thread_local_buffers_.find(key);
3264     if (buf_it == thread_local_buffers_.end()) {
3265       llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3266           IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
3267           &b_, MinimumAlignmentForShape(target_shape));
3268       auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
3269       CHECK(it_inserted_pair.second);
3270       buf_it = it_inserted_pair.first;
3271     }
3272     return buf_it->second;
3273   }();
3274   return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
3275 }
3276 
3277 llvm::Value* IrEmitter::EmitGlobalBufferPointer(
3278     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3279   const BufferAllocation& allocation = *slice.allocation();
3280   llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
3281       GetBufferTableArgument(), b_.getInt8PtrTy(), slice.index(), &b_);
3282   llvm::LoadInst* tempbuf_address_base =
3283       Load(b_.getInt8PtrTy(), tempbuf_address_ptr);
3284   if (hlo_module_config_.debug_options()
3285           .xla_llvm_enable_invariant_load_metadata()) {
3286     tempbuf_address_base->setMetadata(
3287         llvm::LLVMContext::MD_invariant_load,
3288         llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
3289   }
3290   AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
3291   AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
3292 
3293   llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
3294   if (slice.offset() > 0) {
3295     // Adjust the address to account for the slice offset.
3296     tempbuf_address_untyped = InBoundsGEP(b_.getInt8Ty(), tempbuf_address_base,
3297                                           b_.getInt64(slice.offset()));
3298   }
3299   return BitCast(tempbuf_address_untyped,
3300                  IrShapeType(target_shape)->getPointerTo());
3301 }
3302 
3303 llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
3304                                           const Shape& target_shape) {
3305   if (slice.allocation()->is_thread_local()) {
3306     return EmitThreadLocalBufferPointer(slice, target_shape);
3307   } else if (slice.allocation()->is_constant()) {
3308     return BitCast(
3309         FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
3310         IrShapeType(target_shape)->getPointerTo());
3311   } else {
3312     return EmitGlobalBufferPointer(slice, target_shape);
3313   }
3314 }
3315 
3316 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
3317   const Shape& target_shape = op->shape();
3318   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
3319                       assignment_.GetUniqueTopLevelSlice(op));
3320   llvm::Value* addr = EmitBufferPointer(slice, target_shape);
3321   addr->setName(IrName(op));
3322   emitted_value_[op] = addr;
3323   return OkStatus();
3324 }
3325 
3326 Status IrEmitter::EmitTargetElementLoop(
3327     HloInstruction* target_op,
3328     const llvm_ir::ElementGenerator& element_generator) {
3329   return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
3330 }
3331 
3332 Status IrEmitter::EmitTargetElementLoop(
3333     HloInstruction* target_op, absl::string_view desc,
3334     const llvm_ir::ElementGenerator& element_generator) {
3335   VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
3336 
3337   const Shape& target_shape = target_op->shape();
3338   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
3339   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
3340 
3341   if (target_shape.IsTuple() &&
3342       (target_op->opcode() == HloOpcode::kFusion ||
3343        target_op->opcode() == HloOpcode::kReduce ||
3344        target_op->opcode() == HloOpcode::kReduceWindow)) {
3345     // For multi-output ops, emit each tuple element's buffer, then the root tuple.
3346     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
3347     std::vector<llvm_ir::IrArray> output_arrays;
3348     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
3349       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
3350                           assignment_.GetUniqueSlice(target_op, {i}));
3351       const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
3352       llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
3353       llvm::Type* op_target_type = IrShapeType(element_shape);
3354       output_arrays.push_back(
3355           llvm_ir::IrArray(op_target_address, op_target_type, element_shape));
3356     }
3357     TF_RETURN_IF_ERROR(
3358         llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
3359             .EmitLoop(IrName(target_op)));
3360 
3361     std::vector<llvm::Value*> tuple_operand_ptrs;
3362     for (int64_t i = 0; i < output_arrays.size(); ++i) {
3363       tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
3364     }
3365     llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
3366 
3367   } else {
3368     if (ShouldEmitParallelLoopFor(*target_op)) {
3369       // Emit code to read dynamic loop bounds from compute function argument.
3370       std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
3371           compute_function_->GetDynamicLoopBounds();
3372       // Emit parallel loop with dynamic loop bounds for most-major dimensions.
3373       TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
3374                                              &dynamic_loop_bounds, &b_)
3375                              .EmitLoop(IrName(target_op)));
3376     } else {
3377       TF_RETURN_IF_ERROR(
3378           llvm_ir::LoopEmitter(element_generator, target_array, &b_)
3379               .EmitLoop(IrName(target_op)));
3380     }
3381   }
3382   return OkStatus();
3383 }
3384 
3385 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
3386                              const HloInstruction& destination) {
3387   llvm::Value* source_value = GetEmittedValueFor(&source);
3388   llvm::Value* destination_value = GetEmittedValueFor(&destination);
3389   int64_t source_size = ByteSizeOf(source.shape());
3390   // TODO(b/63762267): Be more aggressive about specifying alignment.
3391   MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
3392          /*SrcAlign=*/llvm::Align(1), source_size);
3393   return OkStatus();
3394 }
3395 
3396 Status IrEmitter::ElementTypesSameAndSupported(
3397     const HloInstruction& instruction,
3398     absl::Span<const HloInstruction* const> operands,
3399     absl::Span<const PrimitiveType> supported_types) {
3400   for (auto operand : operands) {
3401     TF_RET_CHECK(
3402         ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
3403   }
3404 
3405   TF_RET_CHECK(!operands.empty());
3406   PrimitiveType primitive_type = operands[0]->shape().element_type();
3407   if (!absl::c_linear_search(supported_types, primitive_type)) {
3408     return Unimplemented("unsupported operand type %s in op %s",
3409                          PrimitiveType_Name(primitive_type),
3410                          HloOpcodeString(instruction.opcode()));
3411   }
3412   return OkStatus();
3413 }
3414 
3415 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
3416   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
3417   for (const HloInstruction* operand : hlo->operands()) {
3418     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
3419       return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3420     };
3421   }
3422   CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
3423   return EmitTargetElementLoop(
3424       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
3425 }
3426 
3427 llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
3428     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3429     absl::string_view name) {
3430   std::vector<llvm::Value*> return_value =
3431       EmitThreadLocalCall(callee, parameters, name, /*is_reducer=*/false);
3432   CHECK_EQ(return_value.size(), 1);
3433   return return_value[0];
3434 }
3435 
3436 std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
3437     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3438     absl::string_view name, bool is_reducer) {
3439   CHECK(absl::c_binary_search(thread_local_computations_, &callee));
3440   const Shape& return_shape = callee.root_instruction()->shape();
3441   bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
3442   bool is_tuple_of_scalars_return =
3443       return_shape.IsTuple() &&
3444       absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
3445         return ShapeUtil::IsScalar(shape);
3446       });
3447   CHECK(is_scalar_return || is_tuple_of_scalars_return);
3448 
3449   std::vector<llvm::Value*> parameter_addrs;
3450   for (llvm::Value* parameter : parameters) {
3451     CHECK(!parameter->getType()->isPointerTy());
3452     llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
3453         parameter->getType(), "arg_addr", &b_);
3454     Store(parameter, parameter_addr);
3455     parameter_addrs.push_back(parameter_addr);
3456   }
3457 
3458   llvm::Type* return_value_buffer_type =
3459       llvm_ir::ShapeToIrType(return_shape, module_);
3460   std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
3461   int retval_alignment =
3462       is_scalar_return
3463           ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
3464           : 0;
3465   llvm::AllocaInst* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3466       return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
3467 
3468   std::vector<llvm::Value*> allocas_for_returned_scalars;
3469   if (is_scalar_return) {
3470     allocas_for_returned_scalars.push_back(return_value_buffer);
3471   } else {
3472     constexpr int max_tuple_size = 1000;
3473     CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
3474         << "Multivalue function can not return more than 1000 elements to avoid"
3475         << " stack smashing";
3476     allocas_for_returned_scalars =
3477         llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
3478     llvm_ir::IrArray tuple_array(return_value_buffer, return_value_buffer_type,
3479                                  return_shape);
3480 
3481     EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
3482   }
3483 
3484   Call(
3485       FindOrDie(emitted_functions_,
3486                 ComputationToEmit{&callee, allow_reassociation_ || is_reducer}),
3487       GetArrayFunctionCallArguments(
3488           parameter_addrs, &b_, name,
3489           /*return_value_buffer=*/return_value_buffer,
3490           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3491           /*buffer_table_arg=*/
3492           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
3493           /*status_arg=*/GetStatusArgument(),
3494           /*profile_counters_arg=*/GetProfileCountersArgument()));
3495 
3496   if (ComputationTransitivelyContainsCustomCall(&callee)) {
3497     EmitEarlyReturnIfErrorStatus();
3498   }
3499 
3500   std::vector<llvm::Value*> returned_scalars;
3501   returned_scalars.reserve(allocas_for_returned_scalars.size());
3502   for (llvm::Value* addr : allocas_for_returned_scalars) {
3503     returned_scalars.push_back(
3504         Load(llvm::cast<llvm::AllocaInst>(addr)->getAllocatedType(), addr));
3505   }
3506   return returned_scalars;
3507 }
3508 
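// Emits a call to `callee` as a global computation: no return-value buffer is
// passed because the result lives in globally assigned buffers, and the real
// buffer table argument is forwarded to the callee.
//
// A hypothetical call site, for illustration only (`call` is not a name used
// in this file):
//   EmitGlobalCall(*call->to_apply(), call->name());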
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
                               absl::string_view name) {
  CHECK(absl::c_binary_search(global_computations_, &callee));

  Call(FindOrDie(emitted_functions_,
                 ComputationToEmit{&callee, allow_reassociation_}),
       GetArrayFunctionCallArguments(
           /*parameter_addresses=*/{}, &b_, name,
           /*return_value_buffer=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()),
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/GetBufferTableArgument(),
           /*status_arg=*/GetStatusArgument(),
           /*profile_counters_arg=*/GetProfileCountersArgument()));

  if (ComputationTransitivelyContainsCustomCall(&callee)) {
    EmitEarlyReturnIfErrorStatus();
  }
}

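// Returns the buffer holding the result of a global call to `callee`, or a
// null pointer when the root is an outfeed, which produces no output buffer.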
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
    const HloComputation& callee) {
  const HloInstruction* root_inst = callee.root_instruction();
  if (root_inst->opcode() == HloOpcode::kOutfeed) {
    return llvm::Constant::getNullValue(b_.getInt8PtrTy());
  }

  const BufferAllocation::Slice root_buffer =
      assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
  return EmitBufferPointer(root_buffer, root_inst->shape());
}

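// Binds each fused parameter of `fusion` to a generator that reads the
// corresponding fusion operand's IR array element, so the fused emitter can
// compute the fusion body element by element.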
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        *fusion->fused_parameter(i),
        [this, operand](llvm_ir::IrArray::Index index) {
          return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
        });
  }
}

}  // namespace cpu
}  // namespace xla