1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include <algorithm>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <string>
26 #include <utility>
27 #include <vector>
28
29 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
30 #include "absl/cleanup/cleanup.h"
31 #include "absl/container/flat_hash_map.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_format.h"
35 #include "absl/strings/str_join.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "llvm/CodeGen/TargetRegisterInfo.h"
39 #include "llvm/CodeGen/TargetSubtargetInfo.h"
40 #include "llvm/IR/BasicBlock.h"
41 #include "llvm/IR/Constants.h"
42 #include "llvm/IR/FMF.h"
43 #include "llvm/IR/GlobalVariable.h"
44 #include "llvm/IR/Instructions.h"
45 #include "llvm/IR/Intrinsics.h"
46 #include "llvm/IR/IntrinsicsX86.h"
47 #include "llvm/IR/LLVMContext.h"
48 #include "llvm/IR/Value.h"
49 #include "tensorflow/compiler/xla/layout_util.h"
50 #include "tensorflow/compiler/xla/map_util.h"
51 #include "tensorflow/compiler/xla/primitive_util.h"
52 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
53 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
54 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
55 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
56 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
57 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
58 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
59 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
60 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
61 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
62 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
63 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
64 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
65 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
66 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
67 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
68 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
69 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
71 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
72 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
73 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
74 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
75 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
76 #include "tensorflow/compiler/xla/shape_util.h"
77 #include "tensorflow/compiler/xla/status_macros.h"
78 #include "tensorflow/compiler/xla/types.h"
79 #include "tensorflow/compiler/xla/util.h"
80 #include "tensorflow/compiler/xla/window_util.h"
81 #include "tensorflow/compiler/xla/xla_data.pb.h"
82 #include "tensorflow/core/lib/core/errors.h"
83 #include "tensorflow/core/lib/math/math_util.h"
84 #include "tensorflow/core/platform/logging.h"
85
86 namespace xla {
87
88 namespace {
89 using llvm_ir::IrName;
90 using llvm_ir::SetToFirstInsertPoint;
91 } // namespace
92
93 namespace cpu {
94
95 IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context,
96 const HloModule& hlo_module,
97 const BufferAssignment& assignment,
98 llvm::Module* llvm_module,
99 absl::flat_hash_map<const HloInstruction*, int64_t>
100 instruction_to_profile_idx,
101 absl::flat_hash_map<const HloComputation*, int64_t>
102 computation_to_profile_idx,
103 absl::flat_hash_map<const HloComputation*, bool>
104 computation_transitively_contains_custom_call,
105 const TargetMachineFeatures* target_machine_features,
106 bool emit_code_for_msan)
107 : assignment_(assignment),
108 module_(llvm_module),
109 arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
110 b_(llvm_module->getContext()),
111 mlir_context_(mlir_context),
112 instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
113 computation_to_profile_idx_(std::move(computation_to_profile_idx)),
114 computation_transitively_contains_custom_call_(
115 std::move(computation_transitively_contains_custom_call)),
116 alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
117 hlo_module_config_(hlo_module.config()),
118 is_top_level_computation_(false),
119 target_machine_features_(*target_machine_features),
120 emit_code_for_msan_(emit_code_for_msan) {
121 b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
122 Status s = GatherComputationsByAllocationType(
123 &hlo_module, &thread_local_computations_, &global_computations_);
124 absl::c_sort(thread_local_computations_);
125 absl::c_sort(global_computations_);
126 TF_CHECK_OK(s) << "Should have failed buffer assignment.";
127 }
128
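// Emits the epilogue for a thread-local function: the value produced by the
// computation's root instruction is copied into the compute function's result
// argument. Scalar results are stored directly; tuple results are copied
// element by element.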
129 void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
130 llvm::Argument* out_parameter = compute_function_->result_arg();
131 llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
132 const Shape& return_shape = computation->root_instruction()->shape();
133
134 if (ShapeUtil::IsScalar(return_shape)) {
135 llvm::Value* ret_value =
136 Load(root_value.GetBasePointeeType(), root_value.GetBasePointer(),
137 "load_ret_value");
138 Store(ret_value,
139 BitCast(out_parameter, root_value.GetBasePointer()->getType()));
140 } else {
141 CHECK(return_shape.IsTuple());
142
143 llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
144 llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
145 llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
146
147 for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
148 const Shape& element_shape = return_shape.tuple_shapes(i);
149 llvm::Value* destination = llvm_ir::EmitGetTupleElement(
150 element_shape,
151 /*index=*/i,
152 /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
153 tuple_type, &b_);
154
155 llvm::Value* source = llvm_ir::EmitGetTupleElement(
156 element_shape,
157 /*index=*/i,
158 /*alignment=*/MinimumAlignmentForShape(element_shape),
159 root_value.GetBasePointer(), root_value.GetBasePointeeType(), &b_);
160
161 Store(Load(IrShapeType(element_shape), source), destination);
162 }
163 }
164 }
165
166 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
167 HloComputation* computation, const std::string& function_name_prefix,
168 bool is_top_level_computation,
169 absl::Span<HloInstruction* const> instruction_order,
170 bool allow_reassociation) {
171 std::string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
172 VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
173 is_top_level_computation_ = is_top_level_computation;
174 allow_reassociation_ = allow_reassociation;
175 num_dynamic_loop_bounds_ = 0;
176 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
177 num_dynamic_loop_bounds_ =
178 computation->root_instruction()->outer_dimension_partitions().size();
179 }
180
181 if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
182 TF_ASSIGN_OR_RETURN(
183 computation_root_allocation_,
184 assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
185 }
186
187 bool has_thread_local_param = false;
188 for (const HloInstruction* param : computation->parameter_instructions()) {
189 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
190 assignment_.GetUniqueTopLevelSlice(param));
191 has_thread_local_param |= param_slice.allocation()->is_thread_local();
192 computation_parameter_allocations_[param_slice.allocation()->index()] =
193 param->parameter_number();
194 }
195
196 InitializeIrFunction(function_name);
197   // The rdtscp instruction is x86 specific. We will fall back to LLVM's generic
198 // readcyclecounter if it is unavailable.
199 bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
200 arch_type_ == llvm::Triple::ArchType::x86_64;
201 profiling_state_ = ProfilingState(use_rdtscp);
202
203 tracing_state_.set_enabled(
204 computation->parent()->config().cpu_traceme_enabled());
205
206 llvm::IRBuilderBase::FastMathFlagGuard guard(*builder());
207 llvm::FastMathFlags flags = builder()->getFastMathFlags();
208 flags.setAllowReassoc(flags.allowReassoc() || allow_reassociation);
209 builder()->setFastMathFlags(flags);
210
211 TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
212 llvm::Function* ir_function = compute_function_->function();
213 InsertOrDie(&emitted_functions_,
214 ComputationToEmit{computation, allow_reassociation}, ir_function);
215 // Delete 'compute_function', finalizing 'ir_function' and restoring caller
216 // IR insert point.
217
218 // Function epilogue: copying the value over to either the return register,
219 // or values pointing from the return register. If we have an allocation for
220 // the result, we can cue on whether it is thread_local. If it is a constant,
221 // we use the function parameters' allocations to identify a thread_local
222 // function.
223 const BufferAllocation* root_allocation =
224 computation_root_allocation_.allocation();
225 if (root_allocation &&
226 (root_allocation->is_thread_local() ||
227 (root_allocation->is_constant() && has_thread_local_param))) {
228 EmitThreadLocalFunctionEpilogue(computation);
229 }
230
231 // Destructor for compute_function_ terminates the LLVM function definition.
232 compute_function_.reset();
233 computation_root_allocation_ = BufferAllocation::Slice();
234 computation_parameter_allocations_.clear();
235 return ir_function;
236 }
237
238 void IrEmitter::InitializeIrFunction(const std::string& function_name) {
239 // Functions with local linkage get an inlining bonus. Because we know
240   // a priori that embedded functions (non-entry functions) will not have
241   // their names resolved, give them local linkage.
242 llvm::Function::LinkageTypes linkage =
243 is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
244 : llvm::GlobalValue::InternalLinkage;
245 // Create and initialize new IrFunction.
246 compute_function_.reset(new IrFunction(function_name, linkage,
247 hlo_module_config_, module_, &b_,
248 num_dynamic_loop_bounds_));
249 }
250
251 IrEmitter::~IrEmitter() {}
252
253 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
254 VLOG(2) << "HandleBitcast: " << bitcast->ToString();
255 emitted_value_[bitcast] =
256 BitCast(GetEmittedValueFor(bitcast->operand(0)),
257 IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
258 return OkStatus();
259 }
260
261 llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
262 llvm::Constant* initializer =
263 llvm_ir::ConvertLiteralToIrConstant(literal, module_);
264 llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
265 /*Module=*/*module_,
266 /*Type=*/initializer->getType(),
267 /*isConstant=*/true,
268 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
269 /*Initializer=*/initializer,
270 /*Name=*/"");
271 result_global->setAlignment(
272 llvm::Align(MinimumAlignmentForShape(literal.shape())));
273 result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
274 return llvm::ConstantExpr::getBitCast(
275 result_global, IrShapeType(literal.shape())->getPointerTo());
276 }
277
278 Status IrEmitter::EmitConstantGlobals() {
279 for (const BufferAllocation& allocation : assignment_.Allocations()) {
280 if (!allocation.is_constant()) {
281 continue;
282 }
283
284 const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
285 llvm::Constant* global_for_const;
286 auto it = emitted_literals_.find(&literal);
287 if (it != emitted_literals_.end()) {
288 global_for_const = it->second;
289 } else {
290 global_for_const = EmitGlobalForLiteral(literal);
291 InsertOrDie(&emitted_literals_, &literal, global_for_const);
292 }
293
294 InsertOrDie(&constant_buffer_to_global_, allocation.index(),
295 global_for_const);
296 }
297
298 return OkStatus();
299 }
300
301 Status IrEmitter::HandleConstant(HloInstruction* constant) {
302 VLOG(2) << "HandleConstant: " << constant->ToString();
303 // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
304 // of the constant.
305 return EmitTargetAddressForOp(constant);
306 }
307
308 Status IrEmitter::HandleCopy(HloInstruction* copy) {
309 if (copy->shape().IsTuple() ||
310 (copy->shape().IsArray() &&
311 LayoutUtil::Equal(copy->operand(0)->shape().layout(),
312 copy->shape().layout()))) {
313 // If the layouts are equal this is just a memcpy. kCopy shallow copies a
314 // tuple so just memcpy the top-level buffer for tuples.
315 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
316 return EmitMemcpy(*(copy->operand(0)), *copy);
317 } else if (copy->shape().IsArray()) {
318 // Use the elemental emitter for array shapes.
319 return DefaultAction(copy);
320 }
321 return Unimplemented("unsupported operand type %s for copy instruction",
322 PrimitiveType_Name(copy->shape().element_type()));
323 }
324
325 // Calculate the alignment of a buffer allocated for a given primitive type.
326 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
327 int64_t byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
328 DCHECK_GE(byte_size, 0);
329 // Largest scalar is a complex128 so we don't need to worry about the
330 // int64_t->int truncation here.
331 DCHECK_LE(byte_size, 16);
332
333 // Allocations may be 8-byte aligned if part of a small block.
334 return std::min(int64_t{8}, byte_size);
335 }
336
337 int64_t IrEmitter::ByteSizeOf(const Shape& shape) const {
338 return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
339 }
340
341 // Calculate the alignment of a buffer allocated for a given shape.
342 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
343 if (ShapeUtil::IsScalar(shape)) {
344 return MinimumAlignmentForPrimitiveType(shape.element_type());
345 }
346
347 int64_t buffer_size = ByteSizeOf(shape);
348 DCHECK_GE(buffer_size, 0);
349 DCHECK_LE(buffer_size, SIZE_MAX);
350
351 return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
352 }
353
354 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
355 const Shape& shape) {
356 int alignment = MinimumAlignmentForShape(shape);
357 if (alignment > 1) {
358 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
359 }
360 }
361
362 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
363 int64_t buffer_size) {
364 int alignment =
365 target_machine_features_.minimum_alignment_for_allocation(buffer_size);
366 if (alignment > 1) {
367 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
368 }
369 }
370
371 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
372 const Shape& shape) {
373 AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
374 }
375
376 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
377 int64_t buffer_size) {
378 if (buffer_size > 0) {
379 llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
380 }
381 }
382
383 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
384 // A tuple is an array of pointers, one for each operand. Each pointer points
385 // to the output buffer of its corresponding operand. A GetTupleElement
386 // instruction forwards a pointer to the tuple element buffer at the given
387 // index.
388 const HloInstruction* operand = get_tuple_element->operand(0);
389 const Shape& shape = get_tuple_element->shape();
390 emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
391 shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
392 GetEmittedValueFor(operand), IrShapeType(operand->shape()), &b_);
393 return OkStatus();
394 }
395
396 Status IrEmitter::HandleSelect(HloInstruction* select) {
397 auto pred = select->operand(0);
398 TF_RET_CHECK(pred->shape().element_type() == PRED);
399 return DefaultAction(select);
400 }
401
402 Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
403 HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
404 VLOG(2) << "HandleInfeed: " << infeed->ToString();
405
406 // The infeed operation produces a two-element tuple containing data and a
407 // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
408 const Shape& data_shape = infeed->infeed_shape();
409 DCHECK(ShapeUtil::Equal(data_shape,
410 ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
411 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
412
413 // Write the tuple index table.
414 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
415 assignment_.GetUniqueSlice(infeed, {0}));
416 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
417 llvm::Type* data_type = IrShapeType(data_shape);
418 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
419 assignment_.GetUniqueSlice(infeed, {1}));
420 llvm::Value* token_address = EmitBufferPointer(
421 token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
422 llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
423
424 if (data_shape.IsTuple()) {
425 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
426
427 // For a tuple, we first copy each of the internal elements to
428 // their corresponding target locations. We then construct the
429 // tuple outer buffer containing pointers to the internal
430 // elements.
431 std::vector<llvm::Value*> tuple_element_addresses;
432 for (int i = 0; i < data_shape.tuple_shapes_size(); ++i) {
433 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
434 assignment_.GetUniqueSlice(infeed, {0, i}));
435
436 const Shape& tuple_element_shape =
437 ShapeUtil::GetTupleElementShape(data_shape, i);
438
439 // Only the outer tuple buffer's target address is obtained from
440 // GetEmittedValueFor, to handle the case when Infeed is the root
441 // instruction. Target addresses for internal elements can be obtained
442 // from EmitBufferPointer.
443 llvm::Value* tuple_element_address =
444 EmitBufferPointer(buffer, tuple_element_shape);
445
446 TF_RETURN_IF_ERROR(EmitXfeedTransfer(
447 XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
448
449 tuple_element_addresses.push_back(tuple_element_address);
450 }
451
452 llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_type, data_shape),
453 tuple_element_addresses, &b_);
454 } else {
455 TF_RETURN_IF_ERROR(
456 EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
457 }
458
459 return OkStatus();
460 }
461
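// Transfers a buffer of exactly ByteSizeOf(shape) bytes between the program
// and the xfeed runtime: acquire the runtime buffer, memcpy in the direction
// implied by `kind` (infeed copies into the program buffer, outfeed copies
// out of it), then release the runtime buffer again.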
462 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
463 llvm::Value* program_buffer_address) {
464 int64_t length = ByteSizeOf(shape);
465 if (length < 0 || length > std::numeric_limits<int32_t>::max()) {
466 return InvalidArgument(
467 "xfeed (infeed or outfeed) buffer length %d is outside the valid "
468 "size range",
469 length);
470 }
471 int32_t length_32 = static_cast<int32_t>(length);
472
473 int32_t shape_length;
474 TF_ASSIGN_OR_RETURN(
475 llvm::Value * shape_ptr,
476 llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
477
478 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
479
480 const char* acquire_func_name =
481 kind == XfeedKind::kInfeed
482 ? runtime::kAcquireInfeedBufferForDequeueSymbolName
483 : runtime::kAcquireOutfeedBufferForPopulationSymbolName;
484
485 // Implementation note: this call informs the runtime that it wants a
486 // buffer of size exactly 'length_32', and the runtime is responsible for
487 // check-failing the process if there is a mismatch, versus passing us
488 // back a buffer that we might overrun.
489 llvm::Value* acquired_pointer =
490 EmitCallToFunc(acquire_func_name,
491 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
492 shape_ptr, b_.getInt32(shape_length)},
493 i8_ptr_type);
494 if (kind == XfeedKind::kInfeed) {
495 // Copy to the program buffer address from the acquired buffer.
496 MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
497 acquired_pointer,
498 /*SrcAlign=*/llvm::Align(1), length_32);
499 } else {
500 // Outfeed -- copy from the in-program address to the acquired buffer.
501 MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
502 program_buffer_address,
503 /*SrcAlign=*/llvm::Align(1), length_32);
504 if (emit_code_for_msan_) {
505 // Mark the outfed data as initialized for msan. The buffer gets read by
506 // the host code, which might be msan-instrumented.
507 // TODO(b/66051036): Run the msan instrumentation pass instead.
508 const llvm::DataLayout& dl = module_->getDataLayout();
509 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
510 EmitCallToFunc(
511 "__msan_unpoison",
512 {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
513 b_.getVoidTy());
514 }
515 }
516
517 const char* release_func_name =
518 kind == XfeedKind::kInfeed
519 ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
520 : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
521 EmitCallToFunc(release_func_name,
522 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
523 acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
524 b_.getVoidTy());
525
526 return OkStatus();
527 }
528
529 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
530 // Outfeed produces no useful result, but it does return a token[] that can be
531 // threaded through to other side effecting operations to ensure ordering. In
532 // the IR emitter we treat this token as a normal u8[] and thus need to insert
533 // an entry for it in emitted_value_.
534 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
535
536 HloInstruction* operand = outfeed->operands()[0];
537 const Shape& operand_shape = operand->shape();
538
539 llvm::Value* value = GetEmittedValueFor(operand);
540 if (!operand_shape.IsTuple()) {
541 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
542 }
543
544 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
545
546 for (int i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
547 const Shape& tuple_element_shape =
548 ShapeUtil::GetTupleElementShape(operand_shape, i);
549 llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
550 tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
551 value, IrShapeType(operand_shape), &b_);
552 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
553 tuple_element_shape, tuple_element));
554 }
555
556 return OkStatus();
557 }
558
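// Sort is lowered to the runtime key-value sort: operands are first copied
// into their destination buffers (the sort runs in place), the keys shape is
// normalized to a descending layout to locate the physical dimension to sort,
// and the emitted comparator computation is passed to the runtime call.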
559 Status IrEmitter::HandleSort(HloInstruction* hlo) {
560 const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
561 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
562 Shape keys_shape = sort->keys()->shape();
563 PrimitiveType keys_type = keys_shape.element_type();
564 if (!primitive_util::IsArrayType(keys_type)) {
565 return Unimplemented("Element type %s not supported in the Sort op on CPU.",
566 PrimitiveType_Name(keys_type));
567 }
568 std::vector<llvm::Value*> destination_addresses(sort->operand_count());
569 for (int64_t i = 0; i < sort->operand_count(); ++i) {
570 ShapeIndex shape_index =
571 sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
572 const HloInstruction* operand = sort->operand(i);
573 // We assume that the layout of all involved operands and outputs is the
574 // same.
575 TF_RET_CHECK(
576 LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
577 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
578 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
579
580 // The sort is implemented in-place, therefore we first copy the operand
581 // buffer to the output buffer if they are not the same.
582 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
583 destination_addresses[i] =
584 EmitBufferPointer(destination_buffer, operand->shape());
585 auto source_address = GetAllocationSlice(*operand);
586 if (destination_buffer != source_address) {
587 int64_t primitive_type_size =
588 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
589 auto source_buffer = GetEmittedValueFor(operand);
590 int64_t size = ByteSizeOf(operand->shape());
591 MemCpy(destination_addresses[i],
592 /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
593 /*SrcAlign=*/llvm::Align(primitive_type_size), size);
594 }
595 }
596
597 // Normalize the shape and the dimension to sort.
598 Shape normalized_keys_shape =
599 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
600 auto logical_to_physical =
601 LayoutUtil::MakeLogicalToPhysical(keys_shape.layout());
602 TF_RET_CHECK(sort->sort_dimension() < logical_to_physical.size());
603 int64_t physical_dimension_to_sort =
604 logical_to_physical[sort->sort_dimension()];
605
606 int64_t sort_dimension_elements =
607 normalized_keys_shape.dimensions(physical_dimension_to_sort);
608 int64_t higher_dimensions = 1;
609 for (int64_t i = 0; i < physical_dimension_to_sort; ++i) {
610 higher_dimensions *= normalized_keys_shape.dimensions(i);
611 }
612 int64_t lower_dimensions = 1;
613 for (int64_t i = normalized_keys_shape.rank() - 1;
614 i > physical_dimension_to_sort; --i) {
615 lower_dimensions *= normalized_keys_shape.dimensions(i);
616 }
617
618 CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
619 llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
620 b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
621 &b_);
622 llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
623 b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
624 &b_);
625 for (int64_t i = 0; i < sort->operand_count(); ++i) {
626 llvm::Value* value_as_i8ptr =
627 PointerCast(destination_addresses[i], b_.getInt8PtrTy());
628 llvm::Value* slot_in_values_alloca =
629 ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
630 Store(value_as_i8ptr, slot_in_values_alloca);
631 llvm::Value* slot_in_sizes_alloca =
632 ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
633 llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
634 sort->operand(i)->shape().element_type()));
635 Store(size, slot_in_sizes_alloca);
636 }
637
638 auto less_than_function =
639 FindOrDie(emitted_functions_,
640 ComputationToEmit{sort->to_apply(), allow_reassociation_});
641 EmitCallToFunc(
642 runtime::kKeyValueSortSymbolName,
643 {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
644 b_.getInt64(lower_dimensions), values,
645 b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
646 GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
647 less_than_function},
648 b_.getVoidTy());
649
650 if (sort->values_count() > 0) {
651 llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
652 }
653 return OkStatus();
654 }
655
656 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
657 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
658 llvm::SmallVector<llvm::Value*> base_ptrs;
659 for (auto operand : tuple->operands()) {
660 base_ptrs.push_back(GetEmittedValueFor(operand));
661 }
662 llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
663 return OkStatus();
664 }
665
666 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
667 // Pseudo code for reduce window:
668 //
669 // for (coordinates O in the output)
670 // value = init_value;
671 // for (coordinates W in the window)
672 // for each index i:
673 // input coordinates I_i = O_i * stride_i + W_i - pad_low_i
674 // if I within bounds of input:
675 // value = function(value, input(I));
676 // output(O) = value;
677 //
678 // This is completely un-optimized and just here to have something
679 // that works.
680 bool saved_allow_reassociation = allow_reassociation_;
681 allow_reassociation_ = true;
682 Status status = DefaultAction(reduce_window);
683 allow_reassociation_ = saved_allow_reassociation;
684 return status;
685 }
686
687 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
688 CHECK_EQ(select_and_scatter->operand_count(), 3);
689 const auto operand = select_and_scatter->operand(0);
690 const auto source = select_and_scatter->operand(1);
691 const auto init_value = select_and_scatter->operand(2);
692 const Window& window = select_and_scatter->window();
693 PrimitiveType operand_element_type = operand->shape().element_type();
694 const int64_t rank = operand->shape().rank();
695 CHECK_EQ(rank, source->shape().rank());
696 CHECK_EQ(rank, window.dimensions_size());
697
698 // TODO(b/31410564): Implement dilation for select-and-scatter.
699 if (window_util::HasDilation(window)) {
700 return Unimplemented(
701 "Dilation for SelectAndScatter is not implemented on CPU. ");
702 }
703
704 // Pseudo code for select-and-scatter:
705 //
706 // initialized_flag is initially off for every window, and is turned on after
707 // the first iteration is completed and the first operand value is selected.
708 //
709 // output(*) = init_value
710 // for (coordinates S in the source) {
711 // initialized_flag = false
712 // for (coordinates W in the window) {
713 // I = S * stride + W - pad_low
714 // if I within bounds of operand:
715 // if !initialized_flag or select(selected_value, operand(I)) == false:
716 // selected_value = operand(I)
717 // selected_index = I
718 // initialized_flag = true
719 // }
720 // output(selected_index) = scatter(output(selected_index), source(S))
721 // }
722 //
723
724 // Initialize the output array with the given init_value.
725 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
726 select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
727 [this, init_value](const llvm_ir::IrArray::Index& target_index) {
728 llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
729 return Load(IrShapeType(init_value->shape()), init_value_addr);
730 }));
731
732 // Create a loop to iterate over the source array to scatter to the output.
733 llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
734 const llvm_ir::IrArray::Index source_index =
735 source_loops.AddLoopsForShape(source->shape(), "source");
736 SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
737
738 // Allocate space to keep the currently selected value, its index, and
739 // the boolean initialized_flag, which is initially set to false.
740 llvm::AllocaInst* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
741 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
742 "selected_value_address", &b_,
743 MinimumAlignmentForPrimitiveType(operand_element_type));
744 llvm::AllocaInst* selected_index_address =
745 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
746 b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
747 llvm::AllocaInst* initialized_flag_address =
748 llvm_ir::EmitAllocaAtFunctionEntry(b_.getInt1Ty(),
749 "initialized_flag_address", &b_);
750 Store(b_.getInt1(false), initialized_flag_address);
751
752 // Create the inner loop to iterate over the window.
753 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
754 llvm::SmallVector<int64_t> window_size;
755 for (const auto& dim : window.dimensions()) {
756 window_size.push_back(dim.size());
757 }
758 const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
759 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
760 SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
761
762 // Compute the operand index to visit and evaluate the condition whether the
763 // operand index is within the bounds. The unsigned comparison includes
764 // checking whether the operand index >= 0.
765 llvm::SmallVector<llvm::Value*> operand_multi_index(source_index.size());
766 llvm::Value* in_bounds_condition = b_.getTrue();
767 for (int64_t i = 0; i < rank; ++i) {
768 llvm::Value* strided_index =
769 NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
770 operand_multi_index[i] =
771 NSWSub(NSWAdd(strided_index, window_index[i]),
772 b_.getInt64(window.dimensions(i).padding_low()));
773 llvm::Value* index_condition =
774 ICmpULT(operand_multi_index[i],
775 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
776 in_bounds_condition = And(in_bounds_condition, index_condition);
777 }
778 CHECK(in_bounds_condition != nullptr);
779
780 // Only need to do something if the operand index is within the bounds. First
781 // check if the initialized_flag is set.
782 llvm_ir::LlvmIfData if_in_bounds =
783 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
784 SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
785 llvm_ir::LlvmIfData if_initialized =
786 llvm_ir::EmitIfThenElse(Load(initialized_flag_address->getAllocatedType(),
787 initialized_flag_address),
788 "initialized", &b_);
789
790 // If the initialized_flag is false, initialize the selected value and index
791 // with the currently visiting operand.
792 SetToFirstInsertPoint(if_initialized.false_block, &b_);
793 const auto save_operand_index =
794 [&](const llvm_ir::IrArray::Index& operand_index) {
795 for (int64_t i = 0; i < rank; ++i) {
796 llvm::Value* selected_index_address_slot =
797 InBoundsGEP(selected_index_address->getAllocatedType(),
798 selected_index_address, {b_.getInt32(i)});
799 Store(operand_index[i], selected_index_address_slot);
800 }
801 };
802 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
803 llvm_ir::IrArray::Index operand_index(
804 operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
805 llvm::Value* operand_data =
806 operand_array.EmitReadArrayElement(operand_index, &b_);
807 Store(operand_data, selected_value_address);
808 save_operand_index(operand_index);
809 Store(b_.getInt1(true), initialized_flag_address);
810
811 // If the initialized_flag is true, call the `select` function to potentially
812 // update the selected value and index with the currently visiting operand.
813 SetToFirstInsertPoint(if_initialized.true_block, &b_);
814 llvm::Value* operand_address =
815 operand_array.EmitArrayElementAddress(operand_index, &b_);
816 llvm::Value* operand_element =
817 Load(operand_array.GetElementLlvmType(), operand_address);
818 llvm::Value* result = EmitScalarReturningThreadLocalCall(
819 *select_and_scatter->select(),
820 {Load(selected_value_address->getAllocatedType(), selected_value_address),
821 operand_element},
822 "select_function");
823
824 // If the 'select' function returns false, update the selected value and the
825 // index to the currently visiting operand.
826 llvm::Value* cond = ICmpNE(
827 result,
828 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
829 "boolean_predicate");
830 llvm_ir::LlvmIfData if_select_lhs =
831 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
832 SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
833 Store(Load(operand_array.GetElementLlvmType(), operand_address),
834 selected_value_address);
835 save_operand_index(operand_index);
836
837 // After iterating over the window elements, scatter the source element to
838 // the selected index of the output. The value we store at the output
839 // location is computed by calling the `scatter` function with the source
840 // value and the current output value.
841 SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
842 llvm::SmallVector<llvm::Value*> selected_multi_index;
843 for (int64_t i = 0; i < rank; ++i) {
844 const std::vector<llvm::Value*> gep_index = {b_.getInt32(i)};
845 llvm::Value* selected_index_address_slot =
846 InBoundsGEP(selected_index_address->getAllocatedType(),
847 selected_index_address, gep_index);
848 llvm::Type* type = llvm::GetElementPtrInst::getIndexedType(
849 selected_index_address->getAllocatedType(), gep_index);
850 selected_multi_index.push_back(Load(type, selected_index_address_slot));
851 }
852 llvm_ir::IrArray source_array(GetIrArrayFor(source));
853 llvm::Value* source_value =
854 source_array.EmitReadArrayElement(source_index, &b_);
855 llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
856 llvm_ir::IrArray::Index selected_index(
857 selected_multi_index, output_array.GetShape(), source_index.GetType());
858 llvm::Value* output_value =
859 output_array.EmitReadArrayElement(selected_index, &b_);
860 llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
861 *select_and_scatter->scatter(), {output_value, source_value},
862 "scatter_function");
863 output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
864
865 SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
866 return OkStatus();
867 }
868
869 Status IrEmitter::HandleDot(HloInstruction* dot) {
870 auto lhs = dot->operand(0);
871 auto rhs = dot->operand(1);
872 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
873 /*instruction=*/*dot, /*operands=*/{lhs, rhs},
874 /*supported_types=*/
875 {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
876 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
877
878 if (dnums.lhs_contracting_dimensions_size() != 1) {
879 // This is disallowed by ShapeInference today.
880 return Unimplemented(
881 "Dot with multiple contracting dimensions not implemented.");
882 }
883
884 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
885 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
886
887 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
888 llvm_ir::IrArray target_array = GetIrArrayFor(dot);
889
890 VLOG(2) << "HandleDot: ";
891 VLOG(2) << " lhs operand: "
892 << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
893 VLOG(2) << " rhs operand: "
894 << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
895 VLOG(2) << " target: "
896 << llvm_ir::DumpToString(*target_array.GetBasePointer());
897
898 // Dot operation is complicated so we delegate to a helper class.
899 return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
900 /*addend_array=*/nullptr,
901 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
902 hlo_module_config_, target_machine_features_);
903 }
904
905 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
906 auto lhs = convolution->operand(0);
907 auto rhs = convolution->operand(1);
908 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
909 /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
910 /*supported_types=*/
911 {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
912
913 // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
914 // different data layouts.
915 if (PotentiallyImplementedAsEigenConvolution(*convolution,
916 target_machine_features_)) {
917 const Shape& lhs_shape = lhs->shape();
918 const Shape& rhs_shape = rhs->shape();
919 const Shape& convolution_shape = convolution->shape();
920 // The input, kernel and output agree with respect to layout.
921 if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
922 LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
923 LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
924 // We lower 1D convolutions into calls to the same Eigen function as 2D
925 // convolutions, except that we pretend that the 1D convolution is really
926 // a 2D convolution with the missing dimension set to 1. We also adjust
927 // the padding, dilation parameters as needed.
928 bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
929 llvm::Value* lhs_address = GetEmittedValueFor(lhs);
930 llvm::Value* rhs_address = GetEmittedValueFor(rhs);
931 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
932
933 const ConvolutionDimensionNumbers& dnums =
934 convolution->convolution_dimension_numbers();
935
936 absl::InlinedVector<int64_t, 2> input_dims;
937 absl::InlinedVector<int64_t, 2> kernel_dims;
938 absl::InlinedVector<int64_t, 2> output_dims;
939 if (one_dim_convolution) {
940 input_dims.push_back(1);
941 kernel_dims.push_back(1);
942 output_dims.push_back(1);
943 }
944
945 // Input tensor.
946 const Shape& input_shape = convolution->operand(0)->shape();
947 int64_t input_batch =
948 input_shape.dimensions(dnums.input_batch_dimension());
949 for (int d : dnums.input_spatial_dimensions()) {
950 input_dims.push_back(input_shape.dimensions(d));
951 }
952 int64_t input_channels =
953 input_shape.dimensions(dnums.input_feature_dimension());
954
955 // Kernel tensor.
956 const Shape& kernel_shape = convolution->operand(1)->shape();
957 for (int d : dnums.kernel_spatial_dimensions()) {
958 kernel_dims.push_back(kernel_shape.dimensions(d));
959 }
960 int64_t kernel_channels =
961 kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
962 int64_t kernel_filters =
963 kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
964
965 // Output tensor.
966 const Shape& convolution_shape = convolution->shape();
967 for (int d : dnums.output_spatial_dimensions()) {
968 output_dims.push_back(convolution_shape.dimensions(d));
969 }
970
971 // Extract the window stride for the convolution.
972 const Window& window = convolution->window();
973 absl::InlinedVector<int64_t, 2> strides;
974 absl::InlinedVector<std::pair<int64_t, int64_t>, 2> padding;
975 absl::InlinedVector<int64_t, 2> base_dilation;
976 absl::InlinedVector<int64_t, 2> window_dilation;
977 if (one_dim_convolution) {
978 strides.push_back(1);
979 padding.push_back({0, 0});
980 base_dilation.push_back(1);
981 window_dilation.push_back(1);
982 }
983 for (const auto& d : window.dimensions()) {
984 strides.push_back(d.stride());
985 padding.push_back({d.padding_low(), d.padding_high()});
986 base_dilation.push_back(d.base_dilation());
987 window_dilation.push_back(d.window_dilation());
988 }
989
990 PrimitiveType primitive_type = lhs->shape().element_type();
991 llvm::Type* ir_ptr_type = primitive_type == F16
992 ? b_.getHalfTy()->getPointerTo()
993 : b_.getFloatTy()->getPointerTo();
994 bool multi_threaded =
995 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
996 bool use_mkl_dnn =
997 hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn() &&
998 convolution->feature_group_count() == 1;
999 bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
1000
1001 auto valid_num_dims = [](absl::Span<const int64_t> xs) {
1002 return xs.size() >= 2 && xs.size() <= 3;
1003 };
1004 TF_RET_CHECK(valid_num_dims(input_dims)) << input_dims.size();
1005 TF_RET_CHECK(valid_num_dims(kernel_dims));
1006 TF_RET_CHECK(valid_num_dims(output_dims));
1007 TF_RET_CHECK(valid_num_dims(strides));
1008 TF_RET_CHECK(padding.size() >= 2 && padding.size() <= 3);
1009 TF_RET_CHECK(valid_num_dims(base_dilation));
1010 TF_RET_CHECK(valid_num_dims(window_dilation));
1011
1012       // TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to a
1013       // potential race condition when setting omp_num_threads.
1014 const char* fn_name;
1015 if (input_dims.size() == 2) {
1016 fn_name =
1017 primitive_type == F16
1018 ? (multi_threaded
1019 ? runtime::kEigenConv2DF16SymbolName
1020 : runtime::kEigenSingleThreadedConv2DF16SymbolName)
1021 : (multi_threaded
1022 ? (use_mkl_dnn
1023 ? runtime::kMKLConv2DF32SymbolName
1024 : (use_acl ? runtime::kACLConv2DF32SymbolName
1025 : runtime::kEigenConv2DF32SymbolName))
1026 : runtime::kEigenSingleThreadedConv2DF32SymbolName);
1027 } else if (input_dims.size() == 3) {
1028 fn_name =
1029 primitive_type == F16
1030 ? (multi_threaded
1031 ? runtime::kEigenConv3DF16SymbolName
1032 : runtime::kEigenSingleThreadedConv3DF16SymbolName)
1033 : (multi_threaded
1034 ? runtime::kEigenConv3DF32SymbolName
1035 : runtime::kEigenSingleThreadedConv3DF32SymbolName);
1036 } else {
1037 LOG(FATAL) << "Invalid number of dimensions " << input_dims.size();
1038 }
1039 if (!multi_threaded && use_mkl_dnn) {
1040 LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
1041 "convolution.";
1042 }
1043 std::vector<llvm::Value*> args = {
1044 GetExecutableRunOptionsArgument(),
1045 BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
1046 BitCast(lhs_address, ir_ptr_type),
1047 BitCast(rhs_address, ir_ptr_type),
1048 b_.getInt64(input_batch),
1049 };
1050 for (int64_t d : input_dims) {
1051 args.push_back(b_.getInt64(d));
1052 }
1053 args.push_back(b_.getInt64(input_channels));
1054 for (int64_t d : kernel_dims) {
1055 args.push_back(b_.getInt64(d));
1056 }
1057 args.push_back(b_.getInt64(kernel_channels));
1058 args.push_back(b_.getInt64(kernel_filters));
1059 for (int64_t d : output_dims) {
1060 args.push_back(b_.getInt64(d));
1061 }
1062 for (int64_t d : strides) {
1063 args.push_back(b_.getInt64(d));
1064 }
1065 for (const auto& p : padding) {
1066 args.push_back(b_.getInt64(p.first));
1067 args.push_back(b_.getInt64(p.second));
1068 }
1069 for (int64_t d : base_dilation) {
1070 args.push_back(b_.getInt64(d));
1071 }
1072 for (int64_t d : window_dilation) {
1073 args.push_back(b_.getInt64(d));
1074 }
1075 args.push_back(b_.getInt64(convolution->feature_group_count()));
1076
1077 VLOG(1) << "Ir emitter emitted Convolution to runtime:" << fn_name;
1078 EmitCallToFunc(fn_name, args, b_.getVoidTy(), /*does_not_throw=*/true,
1079 /*only_accesses_arg_memory=*/true);
1080
1081 return OkStatus();
1082 }
1083 }
1084 // This is a completely un-optimized version of convolution just to
1085 // have an early version that works. E.g. the input index and
1086 // padding calculation is not hoisted out of the inner loop.
1087 //
1088 // See the description of convolution in the XLA documentation for the pseudo
1089 // code for convolution.
1090 return DefaultAction(convolution);
1091 }
1092
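// FFT is lowered to a call into the Eigen FFT runtime (multi-threaded or
// single-threaded depending on xla_cpu_multi_thread_eigen), passing the FFT
// type, a double-precision flag, the batch size, and up to three FFT lengths.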
1093 Status IrEmitter::HandleFft(HloInstruction* fft) {
1094 auto operand = fft->operand(0);
1095 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1096 /*instruction=*/*fft, /*operands=*/{operand},
1097 /*supported_types=*/{F32, F64, C64, C128}));
1098 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1099 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1100 VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1101 VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1102
1103 llvm::Value* operand_address = GetEmittedValueFor(operand);
1104 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1105
1106 const std::vector<int64_t>& fft_length = fft->fft_length();
1107 int64_t input_batch = 1;
1108 for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1109 input_batch *= fft->shape().dimensions(i);
1110 }
1111
1112 // Args have been computed, make the call.
1113 llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1114 bool multi_threaded_eigen =
1115 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1116 const char* fn_name = multi_threaded_eigen
1117 ? runtime::kEigenFftSymbolName
1118 : runtime::kEigenSingleThreadedFftSymbolName;
1119 const int fft_rank = fft_length.size();
1120 EmitCallToFunc(
1121 fn_name,
1122 {GetExecutableRunOptionsArgument(),
1123 BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1124 BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1125 b_.getInt32(operand->shape().element_type() == F64 ||
1126 operand->shape().element_type() == C128),
1127 b_.getInt32(fft_rank), b_.getInt64(input_batch),
1128 b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1129 b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1130 b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
1131 b_.getVoidTy(), /*does_not_throw=*/true,
1132 /*only_accesses_arg_memory=*/false,
1133 /*only_accesses_inaccessible_mem_or_arg_mem=*/true);
1134
1135 return OkStatus();
1136 }
1137
1138 Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
1139 // When there is a single replica, a cross replica sum is the identity
1140 // function, and the buffer assignment expects a copy.
1141 //
1142 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1143 // in algebraic-simplifier, but currently on some platforms
1144 // HloModuleConfig::num_replicas changes between when the module is compiled
1145 // and when it's run.
1146 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1147
1148 // CRS with one operand and one replica is simply the identity function.
1149 if (crs->operand_count() == 1) {
1150 return EmitMemcpy(*crs->operand(0), *crs);
1151 }
1152
1153 // CRS with multiple operands and one replica produces a (one-deep) tuple.
1154 std::vector<llvm::Value*> operand_ptrs;
1155 for (int64_t i = 0; i < crs->operand_count(); ++i) {
1156 llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1157 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1158 assignment_.GetUniqueSlice(crs, {i}));
1159
1160 const Shape& operand_shape = crs->operand(i)->shape();
1161 CHECK(operand_shape.IsArray())
1162 << "Operands to all-reduce must be arrays: " << crs->ToString();
1163 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1164
1165 // TODO(b/63762267): Be more aggressive about specifying alignment.
1166 MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
1167 /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
1168 }
1169 llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1170 return OkStatus();
1171 }
1172
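// All-reduce with more than one replica or partition is lowered to a call
// into the CPU runtime: replica groups are serialized to a string, operand
// and result buffer pointers are packed into arrays, and the reduction
// computation must match one of the runtime-supported reduction kinds.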
1173 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
1174 CHECK_GE(crs->operand_count(), 1);
1175 PrimitiveType datatype = crs->operand(0)->shape().element_type();
1176 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1177
1178 bool is_datatype_supported = [&] {
1179 // TODO(cheshire): Fix duplication wrt. cpu_runtime
1180 switch (datatype) {
1181 case PRED:
1182 case S8:
1183 case U8:
1184 case S32:
1185 case U32:
1186 case S64:
1187 case U64:
1188 case F16:
1189 case F32:
1190 case F64:
1191 case C64:
1192 case C128:
1193 return true;
1194 default:
1195 return false;
1196 }
1197 }();
1198
1199 if (!is_datatype_supported) {
1200 return Unimplemented("AllReduce for datatype '%s' is not supported",
1201 primitive_util::LowercasePrimitiveTypeName(datatype));
1202 }
1203
1204 if (!MatchReductionComputation(crs->to_apply()).has_value()) {
1205 return Unimplemented("AllReduce for computation '%s' is not supported",
1206 crs->to_apply()->ToString());
1207 }
1208
1209 std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
1210 int32_t replica_groups_size = replica_groups.size();
1211 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1212
1213 bool is_tuple = crs->operand_count() > 1;
1214 std::vector<llvm::Value*> input_buffer_ptrs;
1215 std::vector<llvm::Value*> output_buffer_ptrs;
1216 if (is_tuple) {
1217 CHECK(crs->shape().IsTuple());
1218
1219 for (int64_t i = 0; i < crs->operand_count(); i++) {
1220 const HloInstruction* op = crs->operand(i);
1221 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1222 assignment_.GetUniqueSlice(crs, {i}));
1223 const Shape& operand_shape = crs->operand(i)->shape();
1224 CHECK(operand_shape.IsArray())
1225 << "Operands to all-reduce must be arrays: " << crs->ToString();
1226 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1227 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1228 }
1229 } else {
1230 Shape shape = crs->operand(0)->shape();
1231 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1232 assignment_.GetUniqueSlice(crs->operand(0), {}));
1233 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1234 assignment_.GetUniqueSlice(crs, {}));
1235 input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
1236 output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
1237 }
1238
1239 llvm::Value* input_buffers =
1240 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1241 llvm::Value* output_buffers =
1242 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1243
1244 int32_t shape_length;
1245 TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
1246 llvm_ir::EncodeSelfDescribingShapeConstant(
1247 crs->shape(), &shape_length, &b_));
1248
1249 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1250 bool use_global_device_ids =
1251 Cast<HloAllReduceInstruction>(crs)->use_global_device_ids();
1252 EmitCallToFunc(
1253 runtime::kAllReduceSymbolName,
1254 {/*run_options=*/GetExecutableRunOptionsArgument(),
1255 /*replica_groups=*/replica_groups_v,
1256 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1257
1258 /*channel_id_present=*/
1259 b_.getInt32(static_cast<int32_t>(crs->channel_id().has_value())),
1260 /*use_global_device_ids=*/
1261 b_.getInt32(static_cast<int32_t>(use_global_device_ids)),
1262 /*op_id=*/
1263 b_.getInt64(crs->channel_id().has_value()
1264 ? *crs->channel_id()
1265 : crs->GetModule()->unique_id()),
1266 /*reduction_kind=*/
1267 b_.getInt32(
1268 static_cast<int32_t>(*MatchReductionComputation(crs->to_apply()))),
1269 /*shape_ptr=*/shape_ptr,
1270 /*shape_length=*/b_.getInt32(shape_length),
1271 /*num_buffers=*/b_.getInt32(crs->operand_count()),
1272 /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1273 /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1274 b_.getVoidTy());
1275
1276 return OkStatus();
1277 }
1278
1279 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1280 if (hlo_module_config_.replica_count() == 1 &&
1281 hlo_module_config_.num_partitions() == 1) {
1282 return HandleAllReduceSingleReplica(crs);
1283 }
1284 return HandleAllReduceMultipleReplica(crs);
1285 }
1286
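// Only the tuple form of all-to-all (without a split dimension) is supported
// here; every operand must be an array whose buffer has the same byte size,
// and the source/destination buffers are handed to the runtime all-to-all.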
1287 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1288 auto* instr = Cast<HloAllToAllInstruction>(instruction);
1289 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1290 CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1291 << "Only tuple AllToAll is supported";
1292
1293 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1294 std::string replica_groups =
1295 ReplicaGroupsToString(instruction->replica_groups());
1296 int32_t replica_groups_size = replica_groups.size();
1297 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1298
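// Every per-operand buffer must have the same byte size (checked below); the
// common size is passed to the all-to-all runtime call.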
1299 int64_t buffer_size = -1;
1300 std::vector<llvm::Value*> input_buffer_ptrs;
1301 std::vector<llvm::Value*> output_buffer_ptrs;
1302
1303 for (int64_t i = 0; i < instruction->operand_count(); i++) {
1304 const HloInstruction* op = instruction->operand(i);
1305 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1306 assignment_.GetUniqueSlice(instruction, {i}));
1307 const Shape& operand_shape = instruction->operand(i)->shape();
1308 CHECK(operand_shape.IsArray())
1309 << "Operands to all-to-all must be arrays: " << instruction->ToString();
1310 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1311 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1312 CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1313 buffer_size = out_slice.size();
1314 }
1315
1316 llvm::Value* input_buffers =
1317 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1318 llvm::Value* output_buffers =
1319 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1320
1321 EmitCallToFunc(
1322 runtime::kAllToAllSymbolName,
1323 {/*run_options=*/GetExecutableRunOptionsArgument(),
1324 /*channel_id_present=*/
1325 b_.getInt32(static_cast<int32_t>(instruction->channel_id().has_value())),
1326 /*op_id=*/
1327 b_.getInt64(instruction->channel_id().has_value()
1328 ? *instruction->channel_id()
1329 : instruction->GetModule()->unique_id()),
1330 /*replica_groups=*/replica_groups_v,
1331 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1332 /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1333 /*buffer_size=*/b_.getInt64(buffer_size),
1334 /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1335 /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1336 b_.getVoidTy());
1337
1338 llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1339 return OkStatus();
1340 }
1341
1342 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1343 auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1344 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
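// Encode the (source, target) pairs as a comma-separated string such as
// "0=1,1=2", which is passed to the collective-permute runtime together with
// its length.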
1345 std::string source_target_pairs = absl::StrJoin(
1346 instr->source_target_pairs(), ",", absl::PairFormatter("="));
1347 llvm::Value* source_target_pairs_v =
1348 b_.CreateGlobalStringPtr(source_target_pairs);
1349
1350 Shape shape = crs->operand(0)->shape();
1351
1352 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1353 assignment_.GetUniqueSlice(crs->operand(0), {}));
1354 llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1355
1356 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1357 assignment_.GetUniqueSlice(crs, {}));
1358 llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1359
1360 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1361 EmitCallToFunc(
1362 runtime::kCollectivePermuteSymbolName,
1363 {/*run_options=*/GetExecutableRunOptionsArgument(),
1364 /*channel_id_present=*/
1365 b_.getInt32(static_cast<int32_t>(crs->channel_id().has_value())),
1366 /*op_id=*/
1367 b_.getInt64(crs->channel_id().has_value()
1368 ? *crs->channel_id()
1369 : crs->GetModule()->unique_id()),
1370 /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1371 /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1372 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1373 /*source_target_pairs=*/source_target_pairs_v,
1374 /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1375 b_.getVoidTy());
1376
1377 return OkStatus();
1378 }
1379
1380 Status IrEmitter::HandlePartitionId(HloInstruction* hlo) {
1381 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1382 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1383 assignment_.GetUniqueSlice(hlo, {}));
1384 llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1385 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1386 EmitCallToFunc(
1387 runtime::kPartitionIdSymbolName,
1388 {/*run_options=*/GetExecutableRunOptionsArgument(),
1389 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1390 b_.getVoidTy());
1391 return OkStatus();
1392 }
1393
1394 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1395 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1396 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1397 assignment_.GetUniqueSlice(hlo, {}));
1398 llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1399 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1400 EmitCallToFunc(
1401 runtime::kReplicaIdSymbolName,
1402 {/*run_options=*/GetExecutableRunOptionsArgument(),
1403 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1404 b_.getVoidTy());
1405 return OkStatus();
1406 }
1407
1408 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1409 VLOG(2) << "HandleParameter: " << parameter->ToString();
1410 return EmitTargetAddressForOp(parameter);
1411 }
1412
1413 // Returns true if the relative order of the unreduced dimensions stays the same
1414 // through the reduce operation.
1415 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1416 DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1417
1418 // Maps dimensions that were not reduced from their dimension numbers in the
1419 // source shape to their dimension numbers in the destination shape.
1420 //
1421 // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1422 // [0->0, 3->1].
1423 absl::flat_hash_map<int64_t, int64_t> unreduced_dim_map;
1424
1425 absl::flat_hash_set<int64_t> reduced_dims(reduce.dimensions().begin(),
1426 reduce.dimensions().end());
1427
1428 const Shape& operand_shape = reduce.operand(0)->shape();
1429 const Shape& result_shape = reduce.shape();
1430
1431 int64_t delta = 0;
1432 for (int64_t i = 0; i < operand_shape.dimensions_size(); i++) {
1433 if (reduced_dims.contains(i)) {
1434 delta++;
1435 } else {
1436 InsertOrDie(&unreduced_dim_map, i, i - delta);
1437 }
1438 }
1439
1440 // Iterate dimensions minor to major and check that the corresponding
1441 // dimensions in the source and target shapes are equivalent.
1442 int64_t result_dim_idx = 0;
1443 for (int64_t operand_dim_idx = 0;
1444 operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1445 int64_t operand_dim =
1446 operand_shape.layout().minor_to_major(operand_dim_idx);
1447 if (!reduced_dims.contains(operand_dim)) {
1448 if (FindOrDie(unreduced_dim_map, operand_dim) !=
1449 result_shape.layout().minor_to_major(result_dim_idx++)) {
1450 return false;
1451 }
1452 }
1453 }
1454
1455 CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1456
1457 return true;
1458 }
1459
1460 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1461 HloComputation* function, std::string* failure_reason) const {
1462 CHECK_EQ(function->num_parameters(), 2);
1463
1464 auto root_instruction = function->root_instruction();
1465 CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1466
1467 if (root_instruction->operand_count() != 2) {
1468 *failure_reason = "root instruction is not a binary operation";
1469 return nullptr;
1470 }
1471
1472 const Shape& root_shape = root_instruction->shape();
1473 if (ShapeUtil::ElementIsComplex(root_shape)) {
1474 // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>
1475 // Complex multiply would be more challenging. We could perhaps use a
1476 // strided load to get all reals in a vector, all imaginary parts in a
1477 // vector, or use CreateShuffleVector on a bitcast to float x [2N].
1478 *failure_reason = "complex values not supported";
1479 return nullptr;
1480 }
1481 bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1482 bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1483 bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1484
1485 auto lhs = root_instruction->operand(0);
1486 auto rhs = root_instruction->operand(1);
1487
1488 auto param_0 = function->parameter_instruction(0);
1489 auto param_1 = function->parameter_instruction(1);
1490 if (!(lhs == param_0 && rhs == param_1) &&
1491 !(rhs == param_0 && lhs == param_1)) {
1492 *failure_reason =
1493 "root instruction is not a binary operation on the incoming arguments";
1494 return nullptr;
1495 }
1496
1497 CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1498
1499 // This is visually similar to ElementalIrEmitter, though conceptually we're
1500 // doing something different here. ElementalIrEmitter emits scalar operations
1501 // while these emit scalar or vector operations depending on the type of the
1502 // operands. See CreateShardedVectorType for the actual types in use here.
1503 switch (root_instruction->opcode()) {
1504 default:
1505 *failure_reason = "did not recognize root instruction opcode";
1506 return nullptr;
1507
1508 case HloOpcode::kAdd:
1509 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1510 llvm::Value* rhs) {
1511 return root_is_integral ? b->CreateAdd(lhs, rhs)
1512 : b->CreateFAdd(lhs, rhs);
1513 };
1514
1515 case HloOpcode::kMultiply:
1516 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1517 llvm::Value* rhs) {
1518 return root_is_integral ? b->CreateMul(lhs, rhs)
1519 : b->CreateFMul(lhs, rhs);
1520 };
1521
1522 case HloOpcode::kAnd:
1523 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1524 return b->CreateAnd(lhs, rhs);
1525 };
1526
1527 case HloOpcode::kOr:
1528 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1529 return b->CreateOr(lhs, rhs);
1530 };
1531
1532 case HloOpcode::kXor:
1533 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1534 return b->CreateXor(lhs, rhs);
1535 };
1536
1537 case HloOpcode::kMaximum:
1538 return [root_is_floating_point, root_is_signed, this](
1539 llvm::IRBuilder<>* b, llvm::Value* lhs,
1540 llvm::Value* rhs) -> llvm::Value* {
1541 if (root_is_floating_point) {
1542 return llvm_ir::EmitFloatMax(
1543 lhs, rhs, b,
1544 hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1545 }
1546
1547 return b->CreateSelect(
1548 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1549 : llvm::ICmpInst::ICMP_UGE,
1550 lhs, rhs),
1551 lhs, rhs);
1552 };
1553
1554 case HloOpcode::kMinimum:
1555 return [root_is_floating_point, root_is_signed, this](
1556 llvm::IRBuilder<>* b, llvm::Value* lhs,
1557 llvm::Value* rhs) -> llvm::Value* {
1558 if (root_is_floating_point) {
1559 return llvm_ir::EmitFloatMin(
1560 lhs, rhs, b,
1561 hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1562 }
1563
1564 return b->CreateSelect(
1565 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1566 : llvm::ICmpInst::ICMP_ULE,
1567 lhs, rhs),
1568 lhs, rhs);
1569 };
1570 }
1571 }
1572
1573 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1574 PrimitiveType element_type, unsigned element_count) {
1575 int vector_register_size_in_elements =
1576 target_machine_features_.vector_register_byte_size(
1577 *compute_function_->function()) /
1578 ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1579
1580 ShardedVectorType sharded_vector_type;
1581 llvm::Type* element_ir_type =
1582 llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1583
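// Decompose element_count into its powers of two and emit one scalar or
// vector type per set bit. For example, element_count == 11 with 4-element
// vector registers produces {T, <2 x T>, <4 x T>, <4 x T>}.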
1584 for (int i = 0, e = 1 + Log2Ceiling(element_count); i < e; i++) {
1585 // For every power of two present in element_count, we generate one or more
1586 // vector or scalar types.
1587 const unsigned current_size_fragment = 1u << i;
1588 if (!(element_count & current_size_fragment)) {
1589 // Power of two not present in element_count.
1590 continue;
1591 }
1592
1593 if (current_size_fragment == 1) {
1594 // Single element, use a scalar type.
1595 sharded_vector_type.push_back(element_ir_type);
1596 continue;
1597 }
1598
1599 // Lower "current_size_fragment" elements using as few vector registers as
1600 // possible.
1601
1602 if (current_size_fragment >= vector_register_size_in_elements) {
1603 auto vector_type = llvm::VectorType::get(
1604 element_ir_type, vector_register_size_in_elements, false);
1605 sharded_vector_type.insert(
1606 sharded_vector_type.end(),
1607 current_size_fragment / vector_register_size_in_elements,
1608 vector_type);
1609
1610 // Both current_size_fragment and vector_register_size_in_elements are
1611 // powers of two.
1612 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1613 continue;
1614 }
1615
1616 // For now we assume that vector_register_size_in_elements and lower powers
1617 // of two are all legal vector sizes (or at least can be lowered easily by
1618 // LLVM).
1619 sharded_vector_type.push_back(
1620 llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1621 }
1622 return sharded_vector_type;
1623 }
1624
1625 StatusOr<IrEmitter::ShardedVector>
1626 IrEmitter::EmitInnerLoopForVectorizedReduction(
1627 const ReductionGenerator& reduction_generator,
1628 const llvm_ir::IrArray::Index& output_index,
1629 const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1630 HloInstruction* arg, absl::Span<const int64_t> dimensions,
1631 llvm::Align element_alignment) {
1632 ShardedVector accumulator;
1633 accumulator.reserve(accumulator_type.size());
1634 for (auto accumulator_shard_type : accumulator_type) {
1635 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1636 accumulator_shard_type, "accumulator", &b_, 0));
1637 }
1638
1639 llvm::Value* init_value_ssa =
1640 Load(IrShapeType(init_value->shape()), GetEmittedValueFor(init_value));
1641
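// Initialize every accumulator shard with the init value, splatting it across
// all lanes of vector-typed shards.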
1642 for (llvm::Value* accumulator_shard : accumulator) {
1643 llvm::Value* initial_value;
1644 auto shard_type =
1645 llvm::cast<llvm::AllocaInst>(accumulator_shard)->getAllocatedType();
1646 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1647 initial_value =
1648 VectorSplat(vector_type->getElementCount(), init_value_ssa);
1649 } else {
1650 initial_value = init_value_ssa;
1651 }
1652
1653 AlignedStore(initial_value, accumulator_shard, element_alignment);
1654 }
1655
1656 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1657 &b_);
1658 std::vector<llvm::Value*> input_multi_index =
1659 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1660 "reduction_dim");
1661
1662 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1663
1664 llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1665 llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1666
1667 for (auto& i : input_multi_index) {
1668 if (i == nullptr) {
1669 i = *it++;
1670 }
1671 }
1672 CHECK(output_index.end() == it);
1673 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1674 b_.getInt64Ty());
1675
1676 llvm::Value* input_address = BitCast(
1677 arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1678
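// Accumulate into each shard in turn, advancing the input pointer by the
// width of the shard just consumed.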
1679 for (int i = 0; i < accumulator.size(); i++) {
1680 auto input_address_typed =
1681 BitCast(input_address, accumulator[i]->getType());
1682 auto alloca = llvm::cast<llvm::AllocaInst>(accumulator[i]);
1683 auto current_accumulator_value = AlignedLoad(
1684 alloca->getAllocatedType(), accumulator[i], element_alignment);
1685 auto addend = AlignedLoad(alloca->getAllocatedType(), input_address_typed,
1686 element_alignment);
1687 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1688
1689 auto reduced_result =
1690 reduction_generator(&b_, current_accumulator_value, addend);
1691 AlignedStore(reduced_result, accumulator[i], element_alignment);
1692
1693 if (i != (accumulator.size() - 1)) {
1694 input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1695 input_address_typed, 1);
1696 }
1697 }
1698
1699 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1700
1701 ShardedVector result_ssa;
1702 result_ssa.reserve(accumulator.size());
1703 for (auto accumulator_shard : accumulator) {
1704 auto alloca = llvm::cast<llvm::AllocaInst>(accumulator_shard);
1705 result_ssa.push_back(AlignedLoad(alloca->getAllocatedType(),
1706 accumulator_shard, element_alignment));
1707 }
1708 return result_ssa;
1709 }
1710
1711 void IrEmitter::EmitShardedVectorStore(
1712 llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1713 llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
1714 for (int i = 0; i < value_to_store.size(); i++) {
1715 auto store_address_typed =
1716 BitCast(store_address,
1717 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1718
1719 auto store_instruction =
1720 AlignedStore(value_to_store[i], store_address_typed, alignment);
1721 containing_array.AnnotateLoadStoreInstructionWithMetadata(
1722 store_instruction);
1723
1724 if (i != (value_to_store.size() - 1)) {
1725 store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1726 store_address_typed, 1);
1727 }
1728 }
1729 }
1730
1731 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1732 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1733 absl::Span<const int64_t> dimensions, HloComputation* function,
1734 std::string* failure_reason) {
1735 if (!reduce->shape().IsArray()) {
1736 *failure_reason = "vectorization of variadic reduce not implemented";
1737 return false;
1738 }
1739
1740 if (!ReductionPreservesLayout(*reduce)) {
1741 return false;
1742 }
1743
1744 ReductionGenerator reduction_generator =
1745 MatchReductionGenerator(function, failure_reason);
1746 if (!reduction_generator) {
1747 return false;
1748 }
1749
1750 int vector_register_size_in_elements =
1751 target_machine_features_.vector_register_byte_size(
1752 *compute_function_->function()) /
1753 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1754 if (vector_register_size_in_elements == 0) {
1755 // Either we don't know the vector register width for the target or the
1756 // vector register is smaller than the size of the primitive type.
1757 return false;
1758 }
1759
1760 int vectorization_factor_in_bytes =
1761 target_machine_features_.vectorization_factor_in_bytes();
1762
1763 // We try to process vectorization_factor elements at the same time.
1764 const int vectorization_factor =
1765 vectorization_factor_in_bytes /
1766 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1767
1768 bool is_reduction_over_minor_dimension = absl::c_linear_search(
1769 dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1770
1771 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1772 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1773 MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1774
1775 if (is_reduction_over_minor_dimension) {
1776 // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1777 *failure_reason = "reduction over minor dimension not implemented";
1778 return false;
1779 }
1780
1781 CHECK(!reduce->shape().IsTuple());
1782 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1783
1784 // We know we're not reducing over the most minor dimension, which means we
1785 // can lower the reduction loop as:
1786 //
1787 // 1. We're reducing over dimensions R0, R1.
1788 // 2. D0 is the most minor dimension.
1789 // 3. VS is the vectorization stride (we want to reduce this many elements at
1790 // once)
1791 //
1792 // for (d1 in D1) {
1793 // for (d0 in D0 with stride VS) {
1794 // vector_acc = init
1795 // for (r1 in R1) {
1796 // for (r0 in R0) {
1797 //         vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1798 // }
1799 // }
1800 // output[d1, d0] = vector_acc
1801 // }
1802 // }
1803
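// Emit loops over all non-reduced dimensions except the innermost one, which
// is handled below with a vectorized stride (plus a possible epilogue).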
1804 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1805 std::vector<llvm::Value*> array_multi_index(
1806 reduce->shape().dimensions_size());
1807 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1808 --i) {
1809 int64_t dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1810 int64_t start_index = 0;
1811 int64_t end_index = reduce->shape().dimensions(dimension);
1812 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1813 start_index, end_index, absl::StrFormat("dim.%d", dimension));
1814 array_multi_index[dimension] = loop->GetIndVarValue();
1815 }
1816
1817 int64_t innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1818 int64_t innermost_dimension_size =
1819 reduce->shape().dimensions(innermost_dimension);
1820
1821 if (llvm::BasicBlock* innermost_body_bb =
1822 loop_nest.GetInnerLoopBodyBasicBlock()) {
1823 SetToFirstInsertPoint(innermost_body_bb, &b_);
1824 }
1825
1826 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1827
1828 if (innermost_dimension_size >= vectorization_factor) {
1829 int64_t start_index = 0;
1830 int64_t end_index = (innermost_dimension_size / vectorization_factor) *
1831 vectorization_factor;
1832 std::unique_ptr<llvm_ir::ForLoop> loop =
1833 loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1834 absl::StrFormat("dim.%d", innermost_dimension));
1835 array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1836
1837 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1838
1839 ShardedVectorType vector_type = CreateShardedVectorType(
1840 reduce->shape().element_type(), vectorization_factor);
1841 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1842 b_.getInt64Ty());
1843 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1844 EmitInnerLoopForVectorizedReduction(
1845 reduction_generator, array_index, vector_type,
1846 init_value, arg, dimensions, element_alignment));
1847
1848 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1849 llvm::Value* output_address =
1850 target_array.EmitArrayElementAddress(array_index, &b_);
1851 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1852 target_array);
1853
1854 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1855 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1856 b_.SetInsertPoint(exit_terminator);
1857 } else {
1858 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1859 b_.SetInsertPoint(loop->GetExitBasicBlock());
1860 }
1861 }
1862
1863 // Since we advance the innermost dimension by the vectorization factor rather
1864 // than by 1, we may need to peel off an "epilogue" iteration to handle the
1865 // remaining elements when the dimension size is not a multiple of that factor:
1866 if (innermost_dimension_size % vectorization_factor) {
1867 // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1868 array_multi_index[innermost_dimension] =
1869 b_.getInt64(innermost_dimension_size -
1870 (innermost_dimension_size % vectorization_factor));
1871
1872 ShardedVectorType vector_type = CreateShardedVectorType(
1873 reduce->shape().element_type(),
1874 innermost_dimension_size % vectorization_factor);
1875 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1876 b_.getInt64Ty());
1877 llvm::IRBuilderBase::FastMathFlagGuard guard(b_);
1878 llvm::FastMathFlags flags = b_.getFastMathFlags();
1879 flags.setAllowReassoc(true);
1880 b_.setFastMathFlags(flags);
1881 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1882 EmitInnerLoopForVectorizedReduction(
1883 reduction_generator, array_index, vector_type,
1884 init_value, arg, dimensions, element_alignment));
1885
1886 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1887 llvm::Value* output_address =
1888 target_array.EmitArrayElementAddress(array_index, &b_);
1889 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1890 target_array);
1891 }
1892
1893 if (outermost_loop_exit_block) {
1894 b_.SetInsertPoint(outermost_loop_exit_block);
1895 }
1896
1897 return true;
1898 }
1899
1900 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1901 auto arg = reduce->mutable_operand(0);
1902 auto init_value = reduce->mutable_operand(1);
1903 absl::Span<const int64_t> dimensions(reduce->dimensions());
1904 HloComputation* function = reduce->to_apply();
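// Allow reassociation while emitting the reduction (vectorizing a
// floating-point reduction reorders the additions); the cleanup below
// restores the previous setting.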
1905 bool saved_allow_reassociation = allow_reassociation_;
1906 allow_reassociation_ = true;
1907 auto cleanup = absl::MakeCleanup([saved_allow_reassociation, this]() {
1908 allow_reassociation_ = saved_allow_reassociation;
1909 });
1910 if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1911 std::string vectorization_failure_reason;
1912 TF_ASSIGN_OR_RETURN(
1913 bool vectorization_successful,
1914 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1915 &vectorization_failure_reason));
1916 if (vectorization_successful) {
1917 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1918 << "\n";
1919 return OkStatus();
1920 } else {
1921 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1922 << vectorization_failure_reason;
1923 }
1924 }
1925
1926 return DefaultAction(reduce);
1927 }
1928
1929 Status IrEmitter::HandleSend(HloInstruction* send) {
1930 // TODO(b/33942983): Support Send/Recv on CPU.
1931 return Unimplemented("Send is not implemented on CPU.");
1932 }
1933
1934 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1935 // TODO(b/33942983): Support Send/Recv on CPU.
1936 return Unimplemented("Send-done is not implemented on CPU.");
1937 }
1938
1939 Status IrEmitter::HandleScatter(HloInstruction*) {
1940 return Unimplemented("Scatter is not implemented on CPUs.");
1941 }
1942
1943 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1944 VLOG(2) << "HandleSlice: " << slice->ToString();
1945 auto operand = slice->operand(0);
1946 // The code below emits a sequential loop nest. For the parallel backend, use
1947 // ParallelLoopEmitter which respects dynamic loop bounds.
1948 if (ShouldEmitParallelLoopFor(*slice)) {
1949 return DefaultAction(slice);
1950 }
1951
1952 // The code below assumes the layouts are equal.
1953 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1954 return DefaultAction(slice);
1955 }
1956
1957 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1958
1959 if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1960 return OkStatus();
1961 }
1962
1963 const Layout& layout = operand->shape().layout();
1964 const int64_t num_dims = operand->shape().dimensions_size();
1965
1966 // The slice lowering finds maximal contiguous blocks of memory that can be
1967 // copied from the source to the target. This is done by looking at the
1968 // source/target layout in minor-to-major order and doing the following:
1969 //
1970 // * Find an initial segment of dimensions along which the slice uses the
1971 // whole dimension. These are the "inner" dimensions and can be folded into
1972 // the memcpy.
1973 //
1974 // * Of the remaining dimensions decide which ones require loops.
1975 //
1976 // * Implement the memcpy within the innermost loop.
1977
1978 absl::flat_hash_set<int64_t> inner_dims;
1979 for (int64_t dim : LayoutUtil::MinorToMajor(layout)) {
1980 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1981 break;
1982 }
1983 inner_dims.insert(dim);
1984 }
1985
1986 const bool is_trivial_copy = (inner_dims.size() == num_dims);
1987 if (is_trivial_copy) {
1988 if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1989 return DefaultAction(slice);
1990 } else {
1991 return EmitMemcpy(*slice, *operand);
1992 }
1993 }
1994
1995 // The memcpy will copy elements that are logically this shape (allowed to be
1996 // scalar).
1997 const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1998 [&inner_dims](int64_t dim) { return inner_dims.contains(dim); },
1999 operand->shape());
2000
2001 const int64_t primitive_elements_per_logical_element =
2002 ShapeUtil::ElementsIn(logical_element_shape);
2003
2004 // memcpy_dim is the innermost (in terms of layout) dimension for which the
2005 // slice does *not* just copy all the elements along the dimension.
2006 const int64_t memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
2007
2008 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
2009 // The number of logical elements that can be copied in a single call
2010 // to memcpy. We can only copy 1 element at a time if there is a non-trivial
2011 // stride.
2012 const int64_t memcpy_logical_elements =
2013 memcpy_is_contiguous
2014 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
2015 : 1;
2016
2017 // Determine the dimensions that get lowered as loops.
2018 llvm::SmallVector<int64_t> outer_dims;
2019 for (int64_t i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
2020 outer_dims.push_back(LayoutUtil::Major(layout, i));
2021 }
2022
2023 // If the slice along the memcpy dimension is not contiguous, then memcpy_dim
2024 // also needs to be wrapped in a loop.
2025 if (!memcpy_is_contiguous) {
2026 outer_dims.push_back(memcpy_dim);
2027 }
2028
2029 llvm_ir::IrArray target_array = GetIrArrayFor(slice);
2030
2031 const int64_t num_outer_loops = outer_dims.size();
2032 llvm_ir::ForLoopNest loops(IrName(slice), &b_);
2033 std::vector<llvm::Value*> target_multi_index =
2034 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
2035
2036 // Only the indices for the outer dimensions have been initialized in
2037 // target_index. The rest of the indices should get initialized to 0, since
2038 // for the rest of the dimensions the copy writes to the full dimension.
2039 std::replace(target_multi_index.begin(), target_multi_index.end(),
2040 static_cast<llvm::Value*>(nullptr),
2041 static_cast<llvm::Value*>(b_.getInt64(0)));
2042 llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
2043 b_.getInt64Ty());
2044
2045 if (num_outer_loops > 0) {
2046 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2047 }
2048
2049 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2050 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
2051 /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
2052 /*strides=*/slice->slice_strides(), /*builder=*/&b_);
2053
2054 llvm::Value* memcpy_dest =
2055 target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
2056 llvm::Value* memcpy_source =
2057 source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
2058
2059 const int64_t memcpy_elements =
2060 primitive_elements_per_logical_element * memcpy_logical_elements;
2061
2062 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
2063 slice->shape().element_type(), target_array,
2064 source_array);
2065
2066 if (VLOG_IS_ON(2)) {
2067 const int64_t memcpy_bytes =
2068 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
2069 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside "
2070 << num_outer_loops << " loops";
2071 }
2072
2073 if (num_outer_loops > 0) {
2074 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2075 }
2076
2077 return OkStatus();
2078 }
2079
2080 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
2081 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
2082 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
2083 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
2084 }
2085 return DefaultAction(dynamic_slice);
2086 }
2087
2088 Status IrEmitter::HandleDynamicUpdateSlice(
2089 HloInstruction* dynamic_update_slice) {
2090 auto update = dynamic_update_slice->operand(1);
2091 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
2092 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2093 return EmitMemcpy(*update, *dynamic_update_slice);
2094 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
2095 assignment_)) {
2096 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2097 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
2098 return llvm_ir::EmitDynamicUpdateSliceInPlace(
2099 operands, GetIrArrayFor(dynamic_update_slice),
2100 IrName(dynamic_update_slice, "in_place"), &b_);
2101 }
2102 return DefaultAction(dynamic_update_slice);
2103 }
2104
2105 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2106 // TODO(b/33942983): Support Send/Recv on CPU.
2107 return Unimplemented("Recv is not implemented on CPU.");
2108 }
2109
2110 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2111 // TODO(b/33942983): Support Send/Recv on CPU.
2112 return Unimplemented("Recv-done is not implemented on CPU.");
2113 }
2114
2115 Status IrEmitter::HandlePad(HloInstruction* pad) {
2116 // The CPU backend does not handle negative padding, but this is OK because
2117 // negative padding should have been removed by the algebraic simplifier.
2118 for (auto& padding_dimension : pad->padding_config().dimensions()) {
2119 if (padding_dimension.edge_padding_low() < 0 ||
2120 padding_dimension.edge_padding_high() < 0) {
2121 return InternalErrorStrCat(
2122 "Encountered negative padding in IrEmitter on CPU. "
2123 "This should have been eliminated at the HLO level. ",
2124 pad->ToString());
2125 }
2126 }
2127
2128 // First, fill every output element with the padding value.
2129 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2130 pad, "initialize",
2131 [this, pad](const llvm_ir::IrArray::Index& target_index) {
2132 const HloInstruction* padding_value = pad->operand(1);
2133 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2134 return Load(IrShapeType(padding_value->shape()), padding_value_addr);
2135 }));
2136
2137 // Create a loop to iterate over the operand elements and update the output
2138 // locations where the operand elements should be stored.
2139 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2140 const HloInstruction* operand = pad->operand(0);
2141 const llvm_ir::IrArray::Index operand_index =
2142 loops.AddLoopsForShape(operand->shape(), "operand");
2143
2144 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2145
2146 // Load an element from the operand.
2147 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2148 llvm::Value* operand_data =
2149 operand_array.EmitReadArrayElement(operand_index, &b_);
2150
2151 // Compute the output index the operand element should be assigned to.
2152 // output_index := edge_padding_low + operand_index * (interior_padding + 1)
2153 const PaddingConfig& padding_config = pad->padding_config();
2154 std::vector<llvm::Value*> output_multi_index;
2155 for (size_t i = 0; i < operand_index.size(); ++i) {
2156 llvm::Value* offset =
2157 Mul(operand_index[i],
2158 b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2159 llvm::Value* index = Add(
2160 offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2161 output_multi_index.push_back(index);
2162 }
2163
2164 // Store the operand element to the computed output location.
2165 llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2166 llvm_ir::IrArray::Index output_index(
2167 output_multi_index, output_array.GetShape(), operand_index.GetType());
2168 output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2169
2170 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2171 return OkStatus();
2172 }
2173
2174 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2175 auto* root = fusion->fused_expression_root();
2176 if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2177 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2178 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2179 FusedIrEmitter fused_emitter(elemental_emitter);
2180 BindFusionArguments(fusion, &fused_emitter);
2181
2182 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2183 // Delegate to common implementation of fused in-place dynamic-update-slice.
2184 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2185 fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2186 } else if (fusion->IsLoopFusion()) {
2187 VLOG(3) << "HandleFusion kLoop";
2188 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2189 FusedIrEmitter fused_emitter(elemental_emitter);
2190 BindFusionArguments(fusion, &fused_emitter);
2191 TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2192 *fusion->fused_expression_root()));
2193 return EmitTargetElementLoop(fusion, generator);
2194 } else if (fusion->IsOutputFusion()) {
2195 VLOG(3) << "HandleFusion kOutput";
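// An output fusion here is a dot whose result is combined with an addend
// parameter; figure out which operand of the fused root is the dot and treat
// the other one as the addend.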
2196 int64_t dot_op_index =
2197 root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2198 const HloInstruction* dot = root->operand(dot_op_index);
2199 CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2200 << dot->ToString() << " "
2201 << fusion->fused_instructions_computation()->ToString();
2202
2203 int64_t dot_lhs_param_number = dot->operand(0)->parameter_number();
2204 int64_t dot_rhs_param_number = dot->operand(1)->parameter_number();
2205 int64_t addend_param_number =
2206 root->operand(1 - dot_op_index)->parameter_number();
2207
2208 Shape target_shape = fusion->shape();
2209 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2210 llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2211
2212 llvm_ir::IrArray lhs_array(
2213 GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2214 llvm_ir::IrArray rhs_array(
2215 GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2216 llvm_ir::IrArray addend_array(
2217 GetIrArrayFor(fusion->operand(addend_param_number)));
2218
2219 TF_RETURN_IF_ERROR(EmitDotOperation(
2220 *dot, target_array, lhs_array, rhs_array, &addend_array,
2221 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2222 hlo_module_config_, target_machine_features_));
2223 return OkStatus();
2224 } else {
2225 return Unimplemented("Fusion kind not implemented on CPU");
2226 }
2227 }
2228
2229 Status IrEmitter::HandleCall(HloInstruction* call) {
2230 HloComputation* computation = call->to_apply();
2231 llvm::Function* call_ir_function = FindOrDie(
2232 emitted_functions_, ComputationToEmit{computation, allow_reassociation_});
2233
2234 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2235
2236 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2237 // Having a nonempty set of 'outer_dimension_partitions' means that this
2238 // computation has been specially selected to be parallelized (one where the
2239 // root instruction is trivially parallelizable, like elementwise addition
2240 // of two tensors). The LLVM function generated for this computation accepts
2241 // an additional set of loop bounds, allowing the caller to control the
2242 // subset of the output that is generated by each call.
2243
2244 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2245 {}, &b_, computation->name(),
2246 /*return_value_buffer=*/emitted_value_[call],
2247 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2248 /*buffer_table_arg=*/GetBufferTableArgument(),
2249 /*status_arg=*/GetStatusArgument(),
2250 /*profile_counters_arg=*/GetProfileCountersArgument());
2251
2252 // The parallel fork/join runtime will call the generated function once for
2253 // each partition in parallel, using an appropriate set of loop bounds for
2254 // each call such that it only generates one partition of the output.
2255 HloInstruction* root = computation->root_instruction();
2256 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2257 call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2258 call_ir_function, computation->name()));
2259
2260 if (ComputationTransitivelyContainsCustomCall(computation)) {
2261 EmitEarlyReturnIfErrorStatus();
2262 }
2263 } else {
2264 EmitGlobalCall(*computation, computation->name());
2265 }
2266
2267 return OkStatus();
2268 }
2269
2270 Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2271 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2272 std::vector<llvm::Value*> dynamic_dims;
2273 int32_t raw_data_size =
2274 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2275 llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2276 llvm::Value* raw_buffer =
2277 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
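// Operands 1..N hold the dynamic dimension sizes. Store each size into the
// metadata area that follows the static data and keep an i64 copy for index
// computations.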
2278 for (int64_t i = 1; i < hlo->operand_count(); ++i) {
2279 const int64_t dim_index = i - 1;
2280 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2281 llvm::LoadInst* dyn_dim_size = Load(IrShapeType(hlo->operand(i)->shape()),
2282 source_buffer, "dyn_dim_size");
2283
2284 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2285 b_.getInt8Ty(), raw_buffer,
2286 raw_data_size + dim_index * sizeof(int32_t));
2287 b_.CreateStore(dyn_dim_size,
2288 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2289 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2290 /*isSigned=*/true,
2291 "i64_dyn_dim_size"));
2292 }
2293
2294 llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2295 // Pseudo code for sliceToDynamic:
2296 //
2297 // for (index i in dynamic_dim)
2298 // dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2299 // dest[dest_index] = source[i]
2300 auto loop_body_emitter =
2301 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2302 llvm::Value* source_element =
2303 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2304 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2305 // Delinearize the index based on the static shape.
2306 llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2307 &b_);
2308 data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2309 return OkStatus();
2310 };
2311 return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2312 dynamic_dims, &b_)
2313 .EmitLoop(IrName(hlo));
2314 }
2315
2316 Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2317 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2318
2319 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2320 assignment_.GetUniqueSlice(hlo, {0}));
2321 std::vector<llvm::Value*> dynamic_dims;
2322 std::vector<llvm::Value*> tuple_operand_ptrs;
2323 const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2324 const Shape& input_shape = hlo->operand(0)->shape();
2325 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2326 llvm::Type* data_type = IrShapeType(data_shape);
2327 llvm_ir::IrArray data_array(data_address, data_type, data_shape);
2328 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2329 llvm::Value* raw_buffer =
2330 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2331 int64_t raw_data_size =
2332 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2333
2334 // Put a placeholder for the data array's pointer
2335 tuple_operand_ptrs.push_back(data_array.GetBasePointer());
2336 // PadToStatic takes a dynamic tensor as input and produces a variadic number
2337 // of outputs: (static_tensor, dynamic_dim_0, dynamic_dim_1, ... )
2338 // The dynamic dimension sizes start at output index 1.
2339 for (int i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2340 // Read from the metadata section of the dynamic input (operand 0).
2341 const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2342 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2343 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2344 assignment_.GetUniqueSlice(hlo, {i}));
2345 llvm::Value* dest_dim_size_address =
2346 EmitBufferPointer(dim_size_slice, data_shape);
2347 const int64_t dim_index = i - 1;
2348 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2349 b_.getInt8Ty(), raw_buffer,
2350 raw_data_size + dim_index * sizeof(int32_t));
2351 llvm::Value* dyn_dim_size = b_.CreateLoad(
2352 b_.getInt32Ty(),
2353 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2354 "dyn_dim_size");
2355 b_.CreateStore(dyn_dim_size,
2356 b_.CreateBitCast(dest_dim_size_address,
2357 b_.getInt32Ty()->getPointerTo()));
2358 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2359 /*isSigned=*/true,
2360 "i64_dyn_dim_size"));
2361 tuple_operand_ptrs.push_back(dest_dim_size_address);
2362 }
2363
2364 // Pseudo code for padToStatic:
2365 //
2366 // for (index i in dynamic_dim)
2367 //   source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2368 // dest[i] = source[source_index]
2369 auto loop_body_emitter =
2370 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2371 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2372 llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2373 llvm::Value* source_element =
2374 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2375 data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2376 return OkStatus();
2377 };
2378 TF_RETURN_IF_ERROR(
2379 llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2380 .EmitLoop(IrName(hlo)));
2381
2382 // Emit static tensor and dynamic sizes as one tuple.
2383 llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2384 return OkStatus();
2385 }
2386
2387 Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2388 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2389 const HloInstruction* input = hlo->operand(0);
2390 const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
2391 const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
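// The kTopKF32SymbolName runtime call below only handles row-major F32 input
// with an optional leading batch dimension; the checks here enforce that.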
2392 TF_RET_CHECK(input->shape().element_type() == F32);
2393 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2394 hlo->shape().tuple_shapes(0).layout()));
2395 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2396 hlo->shape().tuple_shapes(1).layout()));
2397 TF_RET_CHECK(
2398 LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
2399
2400 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2401 assignment_.GetUniqueSlice(hlo->operand(0), {}));
2402 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2403 assignment_.GetUniqueSlice(hlo, {0}));
2404 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2405 assignment_.GetUniqueSlice(hlo, {1}));
2406 llvm::Value* values_ptr =
2407 EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2408 llvm::Value* out_values_ptr =
2409 EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2410 llvm::Value* out_indices_ptr =
2411 EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
2412 EmitCallToFunc(
2413 runtime::kTopKF32SymbolName,
2414 {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2415 b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2416 BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2417 BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2418 BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2419 b_.getVoidTy());
2420
2421 llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2422 &b_);
2423 return OkStatus();
2424 }
2425
2426 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2427 if (custom_call->custom_call_target() == "PadToStatic") {
2428 return HandlePadToStatic(custom_call);
2429 }
2430 if (custom_call->custom_call_target() == "SliceToDynamic") {
2431 return HandleSliceToDynamic(custom_call);
2432 }
2433 if (custom_call->custom_call_target() == "TopK") {
2434 return HandleTopK(custom_call);
2435 }
2436
2437 absl::Span<HloInstruction* const> operands(custom_call->operands());
2438 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
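// Pack the operand buffer pointers into a stack array of i8* so they can be
// handed to the custom-call target as an opaque pointer array.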
2439 llvm::AllocaInst* operands_alloca =
2440 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2441 i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2442 for (size_t i = 0; i < operands.size(); ++i) {
2443 const HloInstruction* operand = operands[i];
2444 llvm::Value* operand_as_i8ptr =
2445 PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2446 llvm::Value* slot_in_operands_alloca = InBoundsGEP(
2447 operands_alloca->getAllocatedType(), operands_alloca, {b_.getInt64(i)});
2448 Store(operand_as_i8ptr, slot_in_operands_alloca);
2449 }
2450 if (emit_code_for_msan_) {
2451 // Mark the alloca as initialized for msan. The buffer gets read by the
2452 // custom callee, which might be msan-instrumented.
2453 // TODO(b/66051036): Run the msan instrumentation pass instead.
2454 const llvm::DataLayout& dl = module_->getDataLayout();
2455 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2456 EmitCallToFunc(
2457 "__msan_unpoison",
2458 {PointerCast(operands_alloca, i8_ptr_type),
2459 llvm::ConstantInt::get(
2460 intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2461 b_.getVoidTy());
2462 }
2463
2464 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2465 // Write the tuple table if the output is a tuple.
2466 if (custom_call->shape().IsTuple()) {
2467 std::vector<llvm::Value*> base_ptrs;
2468 for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2469 ++i) {
2470 const Shape& elem_shape =
2471 ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2472 TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2473 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2474 assignment_.GetUniqueSlice(custom_call, {i}));
2475 llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2476 base_ptrs.push_back(addr);
2477 }
2478 llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2479 }
2480 auto* output_address_arg =
2481 PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2482
2483 auto typed_custom_call = Cast<HloCustomCallInstruction>(custom_call);
2484 switch (typed_custom_call->api_version()) {
2485 case CustomCallApiVersion::API_VERSION_ORIGINAL:
2486 EmitCallToFunc(custom_call->custom_call_target(),
2487 {output_address_arg, operands_alloca}, b_.getVoidTy());
2488 break;
2489 case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
2490 EmitCallToFunc(custom_call->custom_call_target(),
2491 {output_address_arg, operands_alloca, GetStatusArgument()},
2492 b_.getVoidTy());
2493 EmitEarlyReturnIfErrorStatus();
2494 break;
2495 case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: {
2496 absl::string_view opaque = typed_custom_call->opaque();
2497 EmitCallToFunc(custom_call->custom_call_target(),
2498 {output_address_arg, operands_alloca,
2499 b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(opaque)),
2500 b_.getInt64(opaque.size()), GetStatusArgument()},
2501 b_.getVoidTy());
2502 EmitEarlyReturnIfErrorStatus();
2503 break;
2504 }
2505 default:
2506 return InternalError(
2507 "Unknown custom-call API version enum value: %d (%s)",
2508 typed_custom_call->api_version(),
2509 CustomCallApiVersion_Name(typed_custom_call->api_version()));
2510 }
2511
2512 return OkStatus();
2513 }
2514
2515 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2516 // Precondition: Condition computation must return a scalar bool.
2517 HloComputation* condition = xla_while->while_condition();
2518 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2519 condition->root_instruction()->shape().element_type() == PRED)
2520 << "While condition computation must return bool; got: "
2521 << ShapeUtil::HumanString(condition->root_instruction()->shape());
2522 // Check that all while-related buffers share an allocation slice.
2523 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2524 xla_while->shape(),
2525 [this, &xla_while](const Shape& /*subshape*/,
2526 const ShapeIndex& index) -> Status {
2527 auto check = [this](const HloInstruction* a, const HloInstruction* b,
2528 const ShapeIndex& index) {
2529 const BufferAllocation::Slice slice_a =
2530 assignment_.GetUniqueSlice(a, index).value();
2531 const BufferAllocation::Slice slice_b =
2532 assignment_.GetUniqueSlice(b, index).value();
2533 if (slice_a != slice_b) {
2534 return InternalError(
2535 "instruction %s %s does not share slice with "
2536 "instruction %s %s",
2537 a->ToString(), slice_a.ToString(), b->ToString(),
2538 slice_b.ToString());
2539 }
2540 return OkStatus();
2541 };
2542 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2543 TF_RETURN_IF_ERROR(check(
2544 xla_while, xla_while->while_condition()->parameter_instruction(0),
2545 index));
2546 TF_RETURN_IF_ERROR(
2547 check(xla_while, xla_while->while_body()->parameter_instruction(0),
2548 index));
2549 TF_RETURN_IF_ERROR(check(
2550 xla_while, xla_while->while_body()->root_instruction(), index));
2551 return OkStatus();
2552 }));
2553
2554 // Set emitted value to that of 'init' with which it shares an allocation.
2555 const HloInstruction* init = xla_while->operand(0);
2556 emitted_value_[xla_while] = GetEmittedValueFor(init);
2557
2558 // Generating:
2559 // while (Condition(while_result)) {
2560 // // CopyInsertion pass inserts copies which enable 'while_result' to
2561 // // be passed back in as 'Body' parameter.
2562 //     while_result = Body(while_result);
2563 // }
2564
2565 // Terminates the current block with a branch to a while header.
2566 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2567 module_->getContext(), IrName(xla_while, "header"),
2568 compute_function_->function());
2569 Br(header_bb);
2570 b_.SetInsertPoint(header_bb);
2571
2572 // Calls the condition function to determine whether to proceed with the
2573 // body. It must return a bool, so use the scalar call form.
2574 EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2575 llvm::Value* while_predicate = ICmpNE(
2576 Load(IrShapeType(
2577 xla_while->while_condition()->root_instruction()->shape()),
2578 GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2579 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2580
2581 // Branches to the body or to the while exit depending on the condition.
2582 llvm::BasicBlock* body_bb =
2583 llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2584 compute_function_->function());
2585 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2586 module_->getContext(), IrName(xla_while, "exit"));
2587 CondBr(while_predicate, body_bb, exit_bb);
2588
2589 // Calls the body function from the body block.
2590 b_.SetInsertPoint(body_bb);
2591
2592 // Calls the body function.
2593 EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2594
2595 // Finishes with a branch back to the header.
2596 Br(header_bb);
2597
2598 // Adds the exit block to the function and sets the insert point there.
2599 compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2600 b_.SetInsertPoint(exit_bb);
2601
2602 return OkStatus();
2603 }
2604
2605 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2606 HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2607 std::string* failure_reason) {
2608 if (ShouldEmitParallelLoopFor(*concatenate)) {
2609 *failure_reason =
2610 "cannot generate memcpy-based concat for the parallel CPU backend";
2611 return false;
2612 }
2613
2614 const Shape& output_shape = concatenate->shape();
2615 for (auto* op : operands) {
2616 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2617 *failure_reason = "operand has mismatching layouts";
2618 return false;
2619 }
2620 }
2621
2622 // We split the dimensions into three categories: the dimension over which we
2623 // are concatenating (concat_dim), the dimensions that are minor to it
2624 // (inner_dims) and the dimensions that are major to it (outer_dims).
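//
// For example, with minor-to-major order {0, 1, 2, 3} and concat_dim = 1,
// inner_dims is {0} and outer_dims is {2, 3}.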
2625
2626 int64_t concat_dim = concatenate->dimensions(0);
2627 const Layout& output_layout = output_shape.layout();
2628 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2629 auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2630
2631 std::vector<int64_t> inner_dims(output_min2maj.begin(),
2632 concat_dim_layout_itr);
2633 std::vector<int64_t> outer_dims(std::next(concat_dim_layout_itr),
2634 output_min2maj.end());
2635
2636 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2637
2638 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2639 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2640
2641 llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2642 std::vector<llvm::Value*> target_multi_index =
2643 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2644 std::replace(target_multi_index.begin(), target_multi_index.end(),
2645 static_cast<llvm::Value*>(nullptr),
2646 static_cast<llvm::Value*>(b_.getInt64(0)));
2647 llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2648 b_.getInt64Ty());
2649
2650 if (!outer_dims.empty()) {
2651 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2652 }
2653
2654 PrimitiveType primitive_type = output_shape.element_type();
2655 unsigned primitive_type_size =
2656 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2657
2658 // Contiguous subregions from each operand to the concatenate contribute to a
2659 // contiguous subregion in the target buffer starting at target_region_begin.
2660 llvm::Value* target_region_begin = BitCast(
2661 target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2662 i8_ptr_type);
2663 int64_t byte_offset_into_target_region = 0;
2664
2665 int64_t inner_dims_product =
2666 std::accumulate(inner_dims.begin(), inner_dims.end(), int64_t{1},
2667 [&](int64_t product, int64_t inner_dim) {
2668 return product * output_shape.dimensions(inner_dim);
2669 });
2670
2671 // For each operand, emit a memcpy from the operand to the target of size
2672 // equal to the product of inner dimensions.
2673 for (HloInstruction* operand : operands) {
2674 const Shape& input_shape = operand->shape();
2675 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2676 llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2677 b_.getInt64Ty());
2678 llvm::Value* copy_source_address = BitCast(
2679 source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2680 i8_ptr_type);
2681
2682 llvm::Value* copy_target_address =
2683 GEP(b_.getInt8Ty(), target_region_begin,
2684 b_.getInt64(byte_offset_into_target_region));
2685
2686 EmitTransferElements(
2687 copy_target_address, copy_source_address,
2688 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2689 target_array, source_array);
2690
2691 byte_offset_into_target_region += inner_dims_product *
2692 input_shape.dimensions(concat_dim) *
2693 primitive_type_size;
2694 }
2695
2696 if (!outer_dims.empty()) {
2697 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2698 }
2699
2700 return true;
2701 }
2702
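// Emits a call to the C library's variadic printf, passing `fmt` as a global
// string constant followed by `arguments`, and returns the i32 call result.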
2703 llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2704 absl::Span<llvm::Value* const> arguments) {
2705 llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2706 std::vector<llvm::Value*> call_args;
2707 call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2708 absl::c_copy(arguments, std::back_inserter(call_args));
2709 return b_.CreateCall(
2710 b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2711 "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2712 /*isVarArg=*/true)),
2713 call_args);
2714 }
2715
2716 llvm::Value* IrEmitter::EmitPrintfToStderr(
2717 absl::string_view fmt, absl::Span<llvm::Value* const> arguments) {
2718 llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2719 std::vector<llvm::Value*> call_args;
2720 call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2721 absl::c_copy(arguments, std::back_inserter(call_args));
2722 return b_.CreateCall(
2723 b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2724 runtime::kPrintfToStderrSymbolName,
2725 llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2726 /*isVarArg=*/true)),
2727 call_args);
2728 }
2729
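// Declares (or reuses) a C-calling-convention function `func_name` whose
// parameter types are taken from `arguments`, optionally marks it as
// non-throwing and as having restricted memory effects, and emits a call.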
2730 llvm::Value* IrEmitter::EmitCallToFunc(
2731 std::string func_name, const std::vector<llvm::Value*>& arguments,
2732 llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2733 bool only_accesses_inaccessible_mem_or_arg_mem) {
2734 std::vector<llvm::Type*> types;
2735 types.reserve(arguments.size());
2736 absl::c_transform(arguments, std::back_inserter(types),
2737 [&](llvm::Value* val) { return val->getType(); });
2738 llvm::FunctionType* func_type =
2739 llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2740 auto func = llvm::dyn_cast<llvm::Function>(
2741 module_->getOrInsertFunction(func_name, func_type).getCallee());
2742 func->setCallingConv(llvm::CallingConv::C);
2743 if (does_not_throw) {
2744 func->setDoesNotThrow();
2745 }
2746 if (only_accesses_arg_memory) {
2747 func->setOnlyAccessesArgMemory();
2748 }
2749 if (only_accesses_inaccessible_mem_or_arg_mem) {
2750 func->setOnlyAccessesInaccessibleMemOrArgMem();
2751 }
2752 return b_.CreateCall(func, arguments);
2753 }
2754
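// Copies `element_count` elements of `primitive_type` from `source` to
// `target`. A single element becomes an aligned load/store pair annotated with
// the source/target aliasing metadata; anything larger becomes a memcpy that
// carries the merged metadata of both arrays.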
2755 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2756 int64_t element_count,
2757 PrimitiveType primitive_type,
2758 const llvm_ir::IrArray& target_array,
2759 const llvm_ir::IrArray& source_array) {
2760 unsigned primitive_type_size =
2761 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2762 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2763 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2764 llvm::Type* primitive_llvm_type =
2765 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2766 llvm::Type* primitive_ptr_type =
2767 llvm::PointerType::getUnqual(primitive_llvm_type);
2768
2769 if (element_count == 1) {
2770 auto* load_instruction =
2771 AlignedLoad(primitive_llvm_type, BitCast(source, primitive_ptr_type),
2772 element_alignment);
2773 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2774 auto* store_instruction =
2775 AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2776 element_alignment);
2777 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2778 } else {
2779 auto* memcpy_instruction = b_.CreateMemCpy(
2780 target, /*DstAlign=*/llvm::Align(element_alignment), source,
2781 /*SrcAlign=*/llvm::Align(element_alignment),
2782 element_count * primitive_type_size);
2783
2784 // The memcpy does the load and the store internally. The aliasing related
2785 // metadata has to reflect that.
2786 std::map<int, llvm::MDNode*> merged_metadata =
2787 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2788 target_array.metadata());
2789 for (const auto& kind_md_pair : merged_metadata) {
2790 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2791 }
2792 }
2793 }
2794
2795 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2796 absl::Span<HloInstruction* const> operands(concatenate->operands());
2797 std::string failure_reason;
2798 TF_ASSIGN_OR_RETURN(
2799 bool successful,
2800 EmitFastConcatenate(concatenate, operands, &failure_reason));
2801 if (successful) {
2802 VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2803 return OkStatus();
2804 }
2805
2806 VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2807 << ": " << failure_reason;
2808
2809 return DefaultAction(concatenate);
2810 }
2811
2812 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2813 auto branch_index = conditional->operand(0);
2814 int num_branches = conditional->branch_count();
2815 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2816 (branch_index->shape().element_type() == PRED ||
2817 branch_index->shape().element_type() == S32))
2818 << "Branch index on a conditional must be scalar bool or int32_t; got: "
2819 << ShapeUtil::HumanString(branch_index->shape());
2820
2821 for (int b = 0; b < num_branches; ++b) {
2822 HloComputation* br_computation = conditional->branch_computation(b);
2823 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2824 br_computation->root_instruction()->shape()))
2825 << "Shape of conditional should be same as the shape of the " << b
2826 << "th branch computation; got: "
2827 << ShapeUtil::HumanString(conditional->shape()) << " and "
2828 << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2829 }
2830
2831 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2832
2833 if (branch_index->shape().element_type() == PRED) {
2834 // Emit an if-else to LLVM:
2835 // if (pred)
2836 // cond_result = true_computation(true_operand)
2837 // else
2838 // cond_result = false_computation(false_operand)
2839 llvm::LoadInst* pred_value = Load(
2840 GetIrArrayFor(branch_index).GetBasePointeeType(),
2841 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2842 llvm::Value* pred_cond =
2843 ICmpNE(pred_value,
2844 llvm::ConstantInt::get(
2845 llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2846 "boolean_predicate");
2847 llvm_ir::LlvmIfData if_data =
2848 llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2849
2850 SetToFirstInsertPoint(if_data.true_block, &b_);
2851 EmitGlobalCall(*conditional->branch_computation(0),
2852 IrName(conditional, "_true"));
2853
2854 SetToFirstInsertPoint(if_data.false_block, &b_);
2855 EmitGlobalCall(*conditional->branch_computation(1),
2856 IrName(conditional, "_false"));
2857
2858 SetToFirstInsertPoint(if_data.after_block, &b_);
2859 return OkStatus();
2860 }
2861 // We emit a switch statement to LLVM:
2862 // switch (branch_index) {
2863 // default:
2864 // result = branch_computations[num_branches-1](operands[num_branches-1]);
2865 // break;
2866 // case 0:
2867 // result = branch_computations[0](operands[0]); break;
2868 // case 1:
2869 // result = branch_computations[1](operands[1]); break;
2870 // ...
2871 // case [[num_branches-2]]:
2872 // result = branch_computations[num_branches-2](operands[num_branches-2]);
2873 // break;
2874 // }
2875 llvm::LoadInst* branch_index_value = Load(
2876 GetIrArrayFor(branch_index).GetBasePointeeType(),
2877 GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2878
2879 auto case_block = b_.GetInsertBlock();
2880 llvm::BasicBlock* after_block;
2881 // Add a terminator to the case block, if necessary.
2882 if (case_block->getTerminator() == nullptr) {
2883 after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2884 b_.SetInsertPoint(case_block);
2885 b_.CreateBr(after_block);
2886 } else {
2887 after_block =
2888 case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2889 }
2890 // Our basic block should now end with an unconditional branch. Remove it;
2891 // we're going to replace it with a switch-based branch.
2892 case_block->getTerminator()->eraseFromParent();
2893
2894 // Lower the default branch computation.
2895 auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
2896 b_.SetInsertPoint(default_block);
2897 EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
2898 IrName(conditional, "_default"));
2899 b_.CreateBr(after_block);
2900
2901 // Prepare the switch (branch_index) { ... } instruction.
2902 b_.SetInsertPoint(case_block);
2903 llvm::SwitchInst* case_inst =
2904 b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
2905 // Lower each branch's computation.
2906 for (int b = 0; b < num_branches - 1; ++b) { // last branch is default
2907 // Lower the case b: { ... ; break; } computation.
2908 auto branch_block =
2909 llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
2910 b_.SetInsertPoint(branch_block);
2911 EmitGlobalCall(*conditional->branch_computation(b),
2912 IrName(conditional, absl::StrCat("_branch", b)));
2913 b_.CreateBr(after_block);
2914 case_inst->addCase(b_.getInt32(b), branch_block);
2915 }
2916
2917 SetToFirstInsertPoint(after_block, &b_);
2918 return OkStatus();
2919 }
2920
2921 Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
2922 TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
2923 // No code to generate, but we need to emit an address for book-keeping.
2924 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
2925 return OkStatus();
2926 }
2927
2928 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
2929 // AddDependency just forwards its zeroth operand.
2930 emitted_value_[add_dependency] =
2931 GetEmittedValueFor(add_dependency->operand(0));
2932 return OkStatus();
2933 }
2934
2935 Status IrEmitter::HandleRng(HloInstruction* rng) {
2936 return Unimplemented("Rng should be expanded for CPU.");
2937 }
2938
2939 Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
2940 VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
2941 llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
2942 Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
2943 &b_);
2944
2945 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
2946 llvm::Value* address = GetEmittedValueFor(rng_state);
2947
2948 // The buffer has an array type while the value is an i128. Cast the
2949 // buffer to the i128 type to store the value.
2950 address = BitCast(address, llvm::PointerType::get(
2951 old_state->getType()->getScalarType(),
2952 address->getType()->getPointerAddressSpace()));
2953 llvm::StoreInst* store = Store(old_state, address);
2954 store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
2955 rng_state->shape().element_type())));
2956
2957 return OkStatus();
2958 }
2959
2960 Status IrEmitter::FinishVisit(HloInstruction* root) {
2961 // When this method is called, we should have already emitted an IR value for
2962 // the root (return) op. The IR value holds the address of the buffer holding
2963 // the value. If the root is a constant or parameter, we perform a memcpy from
2964 // this buffer to the retval buffer of the computation. Otherwise, there's
2965 // nothing to do since the result was already written directly into the output
2966 // buffer.
2967 VLOG(2) << "FinishVisit root: " << root->ToString();
2968 if (root->opcode() == HloOpcode::kOutfeed) {
2969 VLOG(2) << " outfeed with value: "
2970 << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
2971 } else {
2972 VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
2973 }
2974
2975 auto record_complete_computation = [&](llvm::Value* prof_counter) {
2976 if (prof_counter) {
2977 profiling_state_.RecordCompleteComputation(&b_, prof_counter);
2978 }
2979 };
2980
2981 // For the entry computation this increment is cumulative of embedded
2982 // computations since it includes cycles spent in computations invoked by
2983 // While, Call etc.
2984 record_complete_computation(GetProfileCounterFor(*root->parent()));
2985 return OkStatus();
2986 }
2987
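// Returns a GEP into the profile counters argument that addresses the counter
// slot assigned to `hlo`, or nullptr if `hlo` is not in `profile_index_map`.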
2988 template <typename T>
2989 llvm::Value* IrEmitter::GetProfileCounterCommon(
2990 const T& hlo,
2991 const absl::flat_hash_map<const T*, int64_t>& profile_index_map) {
2992 auto it = profile_index_map.find(&hlo);
2993 if (it == profile_index_map.end()) {
2994 return nullptr;
2995 }
2996
2997 int64_t prof_counter_idx = it->second;
2998 std::string counter_name = IrName("prof_counter", hlo.name());
2999 return GEP(b_.getInt64Ty(), GetProfileCountersArgument(),
3000 b_.getInt64(prof_counter_idx), counter_name);
3001 }
3002
3003 llvm::Value* IrEmitter::GetProfileCounterFor(
3004 const HloInstruction& instruction) {
3005 return GetProfileCounterCommon<HloInstruction>(instruction,
3006 instruction_to_profile_idx_);
3007 }
3008
3009 llvm::Value* IrEmitter::GetProfileCounterFor(
3010 const HloComputation& computation) {
3011 return GetProfileCounterCommon<HloComputation>(computation,
3012 computation_to_profile_idx_);
3013 }
3014
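// Accumulates (cycle_end - cycle_start) into the counter pointed to by
// prof_counter with a load-add-store sequence.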
3015 void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
3016 llvm::Value* prof_counter,
3017 llvm::Value* cycle_end,
3018 llvm::Value* cycle_start) {
3019 auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
3020 llvm::LoadInst* old_cycle_count = b->CreateLoad(
3021 llvm::cast<llvm::GetElementPtrInst>(prof_counter)->getSourceElementType(),
3022 prof_counter, "old_cycle_count");
3023 auto* new_cycle_count =
3024 b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
3025 b->CreateStore(new_cycle_count, prof_counter);
3026 }
3027
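// Reads the CPU cycle counter, either through the generic
// llvm.readcyclecounter intrinsic or, when requested, through x86 rdtscp
// (keeping only the TSC word of its {tsc, aux} result).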
3028 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
3029 llvm::Module* module = b->GetInsertBlock()->getModule();
3030 if (!use_rdtscp_) {
3031 llvm::Function* func_llvm_readcyclecounter =
3032 llvm::Intrinsic::getDeclaration(module,
3033 llvm::Intrinsic::readcyclecounter);
3034 return b->CreateCall(func_llvm_readcyclecounter);
3035 }
3036 llvm::Function* func_llvm_x86_rdtscp =
3037 llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
3038 llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
3039 return b->CreateExtractValue(rdtscp_call, {0});
3040 }
3041
3042 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
3043 HloInstruction* hlo) {
3044 auto* cycle_start = ReadCycleCounter(b);
3045 cycle_start->setName(IrName(hlo, "cycle_start"));
3046 cycle_starts_[hlo] = cycle_start;
3047 if (first_read_cycle_start_ == nullptr) {
3048 first_read_cycle_start_ = cycle_start;
3049 }
3050 }
3051
3052 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
3053 HloInstruction* hlo,
3054 llvm::Value* prof_counter) {
3055 auto* cycle_end = ReadCycleCounter(b);
3056 cycle_end->setName(IrName(hlo, "cycle_end"));
3057 auto* cycle_start = cycle_starts_[hlo];
3058 UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
3059 last_read_cycle_end_ = cycle_end;
3060 }
3061
3062 void IrEmitter::ProfilingState::RecordCompleteComputation(
3063 llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
3064 if (last_read_cycle_end_ && first_read_cycle_start_) {
3065 UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
3066 first_read_cycle_start_);
3067 }
3068 }
3069
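// When tracing is enabled, emits a call to the runtime tracing-start entry
// point (runtime::kTracingStartSymbolName) with the run options and the HLO's
// name, and records the returned activity id for the matching EmitTracingEnd.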
3070 void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
3071 HloInstruction* hlo,
3072 llvm::Value* run_options) {
3073 if (!enabled_) {
3074 return;
3075 }
3076
3077 llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
3078 llvm::Type* void_ptr_type =
3079 int8_ptr_type; // LLVM does not have a void*, so we use an int8_t* instead.
3080 llvm::FunctionType* fn_type =
3081 llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
3082 /*isVarArg=*/false);
3083
3084 llvm::Function* function = b->GetInsertBlock()->getParent();
3085 llvm::Module* module = function->getParent();
3086 const char* fn_name = runtime::kTracingStartSymbolName;
3087 llvm::FunctionCallee trace_func =
3088 module->getOrInsertFunction(fn_name, fn_type);
3089 if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
3090 fn->setCallingConv(llvm::CallingConv::C);
3091 fn->setDoesNotThrow();
3092 fn->setOnlyAccessesArgMemory();
3093 }
3094 auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
3095 auto* activity_id =
3096 b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
3097 b->CreateBitCast(hlo_name, int8_ptr_type)});
3098 activity_id->setName(IrName(hlo, "activity_id"));
3099 activity_ids_[hlo] = activity_id;
3100 }
3101
3102 void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
3103 HloInstruction* hlo,
3104 llvm::Value* run_options) {
3105 if (!enabled_) {
3106 return;
3107 }
3108
3109 llvm::Type* void_ptr_type =
3110 b->getInt8Ty()->getPointerTo(); // LLVM does not have a void*, so we
3111 // use an int8_t* instead.
3112 llvm::FunctionType* fn_type =
3113 llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
3114 /*isVarArg=*/false);
3115
3116 llvm::Function* function = b->GetInsertBlock()->getParent();
3117 llvm::Module* module = function->getParent();
3118 const char* fn_name = runtime::kTracingEndSymbolName;
3119 llvm::FunctionCallee trace_func =
3120 module->getOrInsertFunction(fn_name, fn_type);
3121 if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
3122 fn->setCallingConv(llvm::CallingConv::C);
3123 fn->setDoesNotThrow();
3124 fn->setOnlyAccessesArgMemory();
3125 }
3126 auto* activity_id = activity_ids_.at(hlo);
3127 b->CreateCall(trace_func,
3128 {b->CreateBitCast(run_options, void_ptr_type), activity_id});
3129 }
3130
3131 namespace {
3132 bool IsHloVeryCheap(const HloInstruction* hlo) {
3133 return hlo->opcode() == HloOpcode::kBitcast ||
3134 hlo->opcode() == HloOpcode::kTuple ||
3135 hlo->opcode() == HloOpcode::kGetTupleElement ||
3136 hlo->opcode() == HloOpcode::kParameter ||
3137 hlo->opcode() == HloOpcode::kConstant ||
3138 hlo->opcode() == HloOpcode::kReplicaId;
3139 }
3140 } // namespace
3141
3142 Status IrEmitter::Preprocess(HloInstruction* hlo) {
3143 VLOG(3) << "Visiting: " << hlo->ToString();
3144 // When profiling is enabled, trace the same HLOs that the profiler does.
3145 if (instruction_to_profile_idx_.count(hlo) ||
3146 (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3147 tracing_state_.EmitTracingStart(&b_, hlo,
3148 GetExecutableRunOptionsArgument());
3149 profiling_state_.RecordCycleStart(&b_, hlo);
3150 }
3151 return OkStatus();
3152 }
3153
3154 Status IrEmitter::Postprocess(HloInstruction* hlo) {
3155 if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
3156 profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
3157 }
3158 // When profiling is enabled, trace the same HLOs that the profiler does.
3159 if (instruction_to_profile_idx_.count(hlo) ||
3160 (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3161 tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
3162 }
3163 return OkStatus();
3164 }
3165
3166 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
3167 llvm::Value* value_for_op = GetEmittedValueFor(hlo);
3168
3169 llvm::Type* ir_type = IrShapeType(hlo->shape());
3170 llvm_ir::IrArray array(value_for_op, ir_type, hlo->shape());
3171 AddAliasingInformationToIrArray(*hlo, &array);
3172 return array;
3173 }
3174
3175 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
3176 const HloInstruction* hlo) {
3177 std::vector<llvm_ir::IrArray> arrays;
3178 std::transform(
3179 hlo->operands().begin(), hlo->operands().end(),
3180 std::back_inserter(arrays),
3181 [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
3182 return arrays;
3183 }
3184
3185 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
3186 auto it = emitted_value_.find(hlo);
3187 if (it == emitted_value_.end()) {
3188 LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
3189 }
3190 return it->second;
3191 }
3192
3193 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
3194 return llvm_ir::ShapeToIrType(shape, module_);
3195 }
3196
3197 llvm::Value* IrEmitter::GetProfileCountersArgument() {
3198 return compute_function_->profile_counters_arg();
3199 }
3200
3201 llvm::Value* IrEmitter::GetStatusArgument() {
3202 return compute_function_->status_arg();
3203 }
3204
3205 llvm::Value* IrEmitter::GetBufferTableArgument() {
3206 return compute_function_->buffer_table_arg();
3207 }
3208
3209 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
3210 return compute_function_->exec_run_options_arg();
3211 }
3212
3213 llvm::BasicBlock* IrEmitter::GetReturnBlock() {
3214 return compute_function_->return_block();
3215 }
3216
3217 void IrEmitter::EmitEarlyReturnIfErrorStatus() {
3218 // Use the runtime helper to get the success/failure state as a boolean.
3219 llvm::Value* succeeded =
3220 EmitCallToFunc(runtime::kStatusIsSuccessSymbolName, {GetStatusArgument()},
3221 b_.getInt1Ty(), /*does_not_throw=*/true,
3222 /*only_accesses_arg_memory=*/true);
3223 llvm_ir::EmitEarlyReturn(succeeded, &b_, GetReturnBlock());
3224 }
3225
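// Returns a typed pointer for a thread-local slice. Parameter allocations are
// loaded from the compute function's params array; other thread-local slices
// are backed by a function-entry alloca that is cached per (function, slice)
// pair and reused on later requests.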
3226 llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
3227 const BufferAllocation::Slice& slice, const Shape& target_shape) {
3228 const BufferAllocation& allocation = *slice.allocation();
3229 llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
3230 auto param_it =
3231 computation_parameter_allocations_.find(slice.allocation()->index());
3232 if (param_it != computation_parameter_allocations_.end()) {
3233 int64_t param_number = param_it->second;
3234 // We have to access the parameter at offset param_number in the params
3235 // array. The code generated here is equivalent to this C code:
3236 //
3237 // i8* param_address_untyped = params[param_number];
3238 // Param* param_address_typed = (Param*)param_address_untyped;
3239 //
3240 // Where Param is the actual element type of the underlying buffer (for
3241 // example, float for an XLA F32 element type).
3242 llvm::Value* params = compute_function_->parameters_arg();
3243 llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(
3244 params, b_.getInt8PtrTy(), param_number, &b_);
3245 llvm::LoadInst* param_address_untyped =
3246 Load(b_.getInt8PtrTy(), param_address_offset);
3247
3248 if (!target_shape.IsOpaque()) {
3249 AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
3250 AttachDereferenceableMetadataForLoad(param_address_untyped,
3251 target_shape);
3252 }
3253 return param_address_untyped;
3254 }
3255
3256 // Thread-local allocations should only be assigned a single buffer.
3257 const auto& assigned_buffers = allocation.assigned_buffers();
3258 CHECK_EQ(1, assigned_buffers.size());
3259 const Shape& shape = assigned_buffers.begin()->first->shape();
3260
3261 std::pair<llvm::Function*, BufferAllocation::Slice> key = {
3262 compute_function_->function(), slice};
3263 auto buf_it = thread_local_buffers_.find(key);
3264 if (buf_it == thread_local_buffers_.end()) {
3265 llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3266 IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
3267 &b_, MinimumAlignmentForShape(target_shape));
3268 auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
3269 CHECK(it_inserted_pair.second);
3270 buf_it = it_inserted_pair.first;
3271 }
3272 return buf_it->second;
3273 }();
3274 return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
3275 }
3276
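// Returns a typed pointer for a globally allocated slice: loads the
// allocation's base address from the buffer table (marking the load invariant
// when that metadata is enabled), applies the slice offset, and bitcasts the
// result to the IR type of `target_shape`.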
3277 llvm::Value* IrEmitter::EmitGlobalBufferPointer(
3278 const BufferAllocation::Slice& slice, const Shape& target_shape) {
3279 const BufferAllocation& allocation = *slice.allocation();
3280 llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
3281 GetBufferTableArgument(), b_.getInt8PtrTy(), slice.index(), &b_);
3282 llvm::LoadInst* tempbuf_address_base =
3283 Load(b_.getInt8PtrTy(), tempbuf_address_ptr);
3284 if (hlo_module_config_.debug_options()
3285 .xla_llvm_enable_invariant_load_metadata()) {
3286 tempbuf_address_base->setMetadata(
3287 llvm::LLVMContext::MD_invariant_load,
3288 llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
3289 }
3290 AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
3291 AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
3292
3293 llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
3294 if (slice.offset() > 0) {
3295 // Adjust the address to account for the slice offset.
3296 tempbuf_address_untyped = InBoundsGEP(b_.getInt8Ty(), tempbuf_address_base,
3297 b_.getInt64(slice.offset()));
3298 }
3299 return BitCast(tempbuf_address_untyped,
3300 IrShapeType(target_shape)->getPointerTo());
3301 }
3302
3303 llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
3304 const Shape& target_shape) {
3305 if (slice.allocation()->is_thread_local()) {
3306 return EmitThreadLocalBufferPointer(slice, target_shape);
3307 } else if (slice.allocation()->is_constant()) {
3308 return BitCast(
3309 FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
3310 IrShapeType(target_shape)->getPointerTo());
3311 } else {
3312 return EmitGlobalBufferPointer(slice, target_shape);
3313 }
3314 }
3315
3316 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
3317 const Shape& target_shape = op->shape();
3318 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
3319 assignment_.GetUniqueTopLevelSlice(op));
3320 llvm::Value* addr = EmitBufferPointer(slice, target_shape);
3321 addr->setName(IrName(op));
3322 emitted_value_[op] = addr;
3323 return OkStatus();
3324 }
3325
3326 Status IrEmitter::EmitTargetElementLoop(
3327 HloInstruction* target_op,
3328 const llvm_ir::ElementGenerator& element_generator) {
3329 return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
3330 }
3331
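// Emits a loop nest that computes every element of `target_op`. Tuple-shaped
// fusion/reduce/reduce-window outputs are emitted into one IrArray per tuple
// element followed by the tuple of base pointers; otherwise a parallel loop is
// emitted when the op has dynamic loop bounds, and a serial loop otherwise.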
3332 Status IrEmitter::EmitTargetElementLoop(
3333 HloInstruction* target_op, absl::string_view desc,
3334 const llvm_ir::ElementGenerator& element_generator) {
3335 VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
3336
3337 const Shape& target_shape = target_op->shape();
3338 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
3339 llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
3340
3341 if (target_shape.IsTuple() &&
3342 (target_op->opcode() == HloOpcode::kFusion ||
3343 target_op->opcode() == HloOpcode::kReduce ||
3344 target_op->opcode() == HloOpcode::kReduceWindow)) {
3345 // For multi-output ops, emit one output array per tuple element, then the tuple of pointers to them.
3346 TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
3347 std::vector<llvm_ir::IrArray> output_arrays;
3348 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
3349 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
3350 assignment_.GetUniqueSlice(target_op, {i}));
3351 const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
3352 llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
3353 llvm::Type* op_target_type = IrShapeType(element_shape);
3354 output_arrays.push_back(
3355 llvm_ir::IrArray(op_target_address, op_target_type, element_shape));
3356 }
3357 TF_RETURN_IF_ERROR(
3358 llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
3359 .EmitLoop(IrName(target_op)));
3360
3361 std::vector<llvm::Value*> tuple_operand_ptrs;
3362 for (int64_t i = 0; i < output_arrays.size(); ++i) {
3363 tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
3364 }
3365 llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
3366
3367 } else {
3368 if (ShouldEmitParallelLoopFor(*target_op)) {
3369 // Emit code to read dynamic loop bounds from compute function argument.
3370 std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
3371 compute_function_->GetDynamicLoopBounds();
3372 // Emit parallel loop with dynamic loop bounds for most-major dimensions.
3373 TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
3374 &dynamic_loop_bounds, &b_)
3375 .EmitLoop(IrName(target_op)));
3376 } else {
3377 TF_RETURN_IF_ERROR(
3378 llvm_ir::LoopEmitter(element_generator, target_array, &b_)
3379 .EmitLoop(IrName(target_op)));
3380 }
3381 }
3382 return OkStatus();
3383 }
3384
3385 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
3386 const HloInstruction& destination) {
3387 llvm::Value* source_value = GetEmittedValueFor(&source);
3388 llvm::Value* destination_value = GetEmittedValueFor(&destination);
3389 int64_t source_size = ByteSizeOf(source.shape());
3390 // TODO(b/63762267): Be more aggressive about specifying alignment.
3391 MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
3392 /*SrcAlign=*/llvm::Align(1), source_size);
3393 return OkStatus();
3394 }
3395
3396 Status IrEmitter::ElementTypesSameAndSupported(
3397 const HloInstruction& instruction,
3398 absl::Span<const HloInstruction* const> operands,
3399 absl::Span<const PrimitiveType> supported_types) {
3400 for (auto operand : operands) {
3401 TF_RET_CHECK(
3402 ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
3403 }
3404
3405 TF_RET_CHECK(!operands.empty());
3406 PrimitiveType primitive_type = operands[0]->shape().element_type();
3407 if (!absl::c_linear_search(supported_types, primitive_type)) {
3408 return Unimplemented("unsupported operand type %s in op %s",
3409 PrimitiveType_Name(primitive_type),
3410 HloOpcodeString(instruction.opcode()));
3411 }
3412 return OkStatus();
3413 }
3414
3415 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
3416 ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
3417 for (const HloInstruction* operand : hlo->operands()) {
3418 operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
3419 return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3420 };
3421 }
3422 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
3423 return EmitTargetElementLoop(
3424 hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
3425 }
3426
3427 llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
3428 const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3429 absl::string_view name) {
3430 std::vector<llvm::Value*> return_value =
3431 EmitThreadLocalCall(callee, parameters, name, /*is_reducer=*/false);
3432 CHECK_EQ(return_value.size(), 1);
3433 return return_value[0];
3434 }
3435
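// Calls a nested computation with thread-local buffers. Scalar parameters are
// spilled to function-entry allocas and passed by address, and the callee must
// return a scalar or a tuple of scalars, which are loaded back out of the
// return-value alloca(s) after the call.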
3436 std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
3437 const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3438 absl::string_view name, bool is_reducer) {
3439 CHECK(absl::c_binary_search(thread_local_computations_, &callee));
3440 const Shape& return_shape = callee.root_instruction()->shape();
3441 bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
3442 bool is_tuple_of_scalars_return =
3443 return_shape.IsTuple() &&
3444 absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
3445 return ShapeUtil::IsScalar(shape);
3446 });
3447 CHECK(is_scalar_return || is_tuple_of_scalars_return);
3448
3449 std::vector<llvm::Value*> parameter_addrs;
3450 for (llvm::Value* parameter : parameters) {
3451 CHECK(!parameter->getType()->isPointerTy());
3452 llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
3453 parameter->getType(), "arg_addr", &b_);
3454 Store(parameter, parameter_addr);
3455 parameter_addrs.push_back(parameter_addr);
3456 }
3457
3458 llvm::Type* return_value_buffer_type =
3459 llvm_ir::ShapeToIrType(return_shape, module_);
3460 std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
3461 int retval_alignment =
3462 is_scalar_return
3463 ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
3464 : 0;
3465 llvm::AllocaInst* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3466 return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
3467
3468 std::vector<llvm::Value*> allocas_for_returned_scalars;
3469 if (is_scalar_return) {
3470 allocas_for_returned_scalars.push_back(return_value_buffer);
3471 } else {
3472 constexpr int max_tuple_size = 1000;
3473 CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
3474 << "Multivalue function can not return more than 1000 elements to avoid"
3475 << " stack smashing";
3476 allocas_for_returned_scalars =
3477 llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
3478 llvm_ir::IrArray tuple_array(return_value_buffer, return_value_buffer_type,
3479 return_shape);
3480
3481 EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
3482 }
3483
3484 Call(
3485 FindOrDie(emitted_functions_,
3486 ComputationToEmit{&callee, allow_reassociation_ || is_reducer}),
3487 GetArrayFunctionCallArguments(
3488 parameter_addrs, &b_, name,
3489 /*return_value_buffer=*/return_value_buffer,
3490 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3491 /*buffer_table_arg=*/
3492 llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
3493 /*status_arg=*/GetStatusArgument(),
3494 /*profile_counters_arg=*/GetProfileCountersArgument()));
3495
3496 if (ComputationTransitivelyContainsCustomCall(&callee)) {
3497 EmitEarlyReturnIfErrorStatus();
3498 }
3499
3500 std::vector<llvm::Value*> returned_scalars;
3501 returned_scalars.reserve(allocas_for_returned_scalars.size());
3502 for (llvm::Value* addr : allocas_for_returned_scalars) {
3503 returned_scalars.push_back(
3504 Load(llvm::cast<llvm::AllocaInst>(addr)->getAllocatedType(), addr));
3505 }
3506 return returned_scalars;
3507 }
3508
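// Calls an emitted computation that operates on globally allocated buffers; no
// parameter addresses or return-value buffer are passed, only the buffer
// table, run options, status, and profile counters arguments.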
3509 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
3510 absl::string_view name) {
3511 CHECK(absl::c_binary_search(global_computations_, &callee));
3512
3513 Call(FindOrDie(emitted_functions_,
3514 ComputationToEmit{&callee, allow_reassociation_}),
3515 GetArrayFunctionCallArguments(
3516 /*parameter_addresses=*/{}, &b_, name,
3517 /*return_value_buffer=*/
3518 llvm::Constant::getNullValue(b_.getInt8PtrTy()),
3519 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3520 /*buffer_table_arg=*/GetBufferTableArgument(),
3521 /*status_arg=*/GetStatusArgument(),
3522 /*profile_counters_arg=*/GetProfileCountersArgument()));
3523
3524 if (ComputationTransitivelyContainsCustomCall(&callee)) {
3525 EmitEarlyReturnIfErrorStatus();
3526 }
3527 }
3528
3529 llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
3530 const HloComputation& callee) {
3531 const HloInstruction* root_inst = callee.root_instruction();
3532 if (root_inst->opcode() == HloOpcode::kOutfeed) {
3533 return llvm::Constant::getNullValue(b_.getInt8PtrTy());
3534 }
3535
3536 const BufferAllocation::Slice root_buffer =
3537 assignment_.GetUniqueTopLevelSlice(root_inst).value();
3538 return EmitBufferPointer(root_buffer, root_inst->shape());
3539 }
3540
3541 void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
3542 FusedIrEmitter* fused_emitter) {
3543 for (int i = 0; i < fusion->operand_count(); i++) {
3544 const HloInstruction* operand = fusion->operand(i);
3545 fused_emitter->BindGenerator(
3546 *fusion->fused_parameter(i),
3547 [this, operand](llvm_ir::IrArray::Index index) {
3548 return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3549 });
3550 }
3551 }
3552
3553 } // namespace cpu
3554 } // namespace xla
3555