/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note: this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK-fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

std::string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
    absl::string_view name) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}

llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
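    // NaN-propagating path, derived from the instructions below: the ordered
    // >= comparison is false whenever either operand is NaN, and
    // `lhs != lhs` (FCmpUNE of a value against itself) is the standard IR
    // test for a NaN lhs. So a NaN on either side ends up selected.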
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
                                   llvm::Value* index, llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  CHECK(array_type_as_pointer->isOpaqueOrPointeeTypeMatches(element_type));
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

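  // A GlobalVariable is a pointer to the buffer as a whole, so reaching
  // element `index` needs a leading 0 index to step through that pointer
  // first; an ordinary element pointer is indexed directly.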
  return b->CreateInBoundsGEP(
      element_type, array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
                                   int64_t index, llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, element_type, b->getInt64(index), b);
}

llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a
      // BF16 type (the LLVM half type is IEEE 16 bit floating point, not
      // bfloat), so we can't map it directly to an LLVM type. We will not map
      // a BF16 addition to an addition on this type (int16_t) - this is just
      // the type used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex64");
      if (cplx_t == nullptr) {
        // The C++ standard dictates that the memory layout of std::complex is
        // contiguous real followed by imaginary. C++11 section 26.4
        // [complex.numbers]:
        //   If z is an lvalue expression of type cv std::complex<T> then the
        //   expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        //   reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part
        //   of z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        //   imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8_t*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
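    // Wrap the most-minor dimension first; e.g. an f32[3,2] shape with
    // layout {1,0} becomes [3 x [2 x float]].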
    for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32_t* shape_size,
                                                         llvm::IRBuilder<>* b) {
  std::string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32_t>::max()) {
    return InternalError("Encoded shape size exceeded int32_t size limit.");
  }
  *shape_size = static_cast<int32_t>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  // Both AMDGPU and NVPTX use the same address space for shared memory.
  const int kGPUSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(llvm::Align(alignment));
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore=*/insert_before);
}

LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
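  // Resulting CFG, sketched from the code below: `if_block` ends in a
  // conditional branch to `true_block` and either `false_block` (when
  // emit_else is set) or `after_block`; the true/false blocks both branch to
  // `after_block`, and the builder is left positioned at the start of
  // `after_block`.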
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b, absl::string_view name) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result =
        b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
  } else {
    comparison_result =
        b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64_t value
// with a tag.
static void LogS64(const char* tag, int64_t value) {
  LOG(INFO) << tag << " (int64_t): " << value;
}

void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
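  // The call below bakes the raw addresses of LogS64 and `tag` into the IR as
  // i64 constants, so the emitted code is only meaningful when JIT-executed
  // inside the current process.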
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64_t>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64_t>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
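  // LLVM !range metadata asserts the value lies in the half-open interval
  // [lower, upper).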
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

std::string IrName(absl::string_view a) {
  std::string s(a);
  s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
  return s;
}

std::string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

std::string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

std::string SanitizeFunctionName(std::string function_name) {
  // The backend with the strictest requirements on function names is NVPTX,
  // so we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

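  // For example, an HLO name like "fusion.1" is rewritten to "fusion_1".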
  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
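  // ror(x, r) == (x << ((n - r) mod n)) | (x >> (r mod n)) for bit width n.
  // Reducing both shift amounts mod n keeps them in [0, n), so e.g. r == 0
  // does not produce an (undefined) shift by n.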
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64_t ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  const auto& options = module_config.debug_options();
  if (!options.xla_cpu_enable_fast_math()) {
    return flags;
  }
  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();
  flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
  flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
  flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
  flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
  return flags;
}

std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.
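  //
  // !alias.scope is merged as a union (the merged access belongs to every
  // scope either input belonged to), while !noalias is merged as an
  // intersection (we may only keep a noalias guarantee both inputs made).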

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const std::string& directory_name,
                                         const std::string& file_name,
                                         const std::string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return OkStatus();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized,
                     absl::string_view filename_suffix) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  std::string suffix =
      absl::StrCat("ir-", optimized ? "with" : "no", "-opt",
                   filename_suffix.empty() ? "" : ".", filename_suffix);
  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the
  // module with the constants stripped to get IR that is easier to
  // manipulate. Skip this if we're dumping to stdout; there's no point in
  // duplicating everything when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT compiled code.
  function->setUWTableKind(llvm::UWTableKind::Default);

  // TensorFlow always flushes denormals to zero, so let LLVM know that
  // flushing denormals is safe. This allows vectorization using ARM's NEON
  // instruction set.
  function->addFnAttr("denormal-fp-math", "preserve-sign");

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

unsigned GetGlobalMemoryAddressSpace() { return 1; }

llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
                                                     llvm::IRBuilder<>* b) {
  static const char* kRngStateVariableName = "rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kRngStateVariableName);
  if (!state_ptr) {
    llvm::Type* state_type = b->getInt128Ty();
    // Use a non-zero initial value, as a zero state can cause the first
    // generated random number to fail the chi-square test. The value used
    // here is arbitrarily chosen; any non-zero value should be fine.
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/state_type,
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
        /*Name=*/kRngStateVariableName,
        /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/GetGlobalMemoryAddressSpace(),
        /*isExternallyInitialized=*/false);
  }
  return state_ptr;
}

llvm::Value* RngGetAndUpdateState(uint64_t delta, llvm::Module* module,
                                  llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForRngState(module, builder);
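  // Fetch-then-add: return the pre-update state and advance the global state
  // by `delta`, so successive calls observe non-overlapping state values.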
  llvm::LoadInst* state_value_old =
      builder->CreateLoad(state_ptr->getValueType(), state_ptr, "load_state");
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old,
      llvm::ConstantInt::get(state_value_old->getType(), delta));
  builder->CreateStore(state_value_new, state_ptr);
  return state_value_old;
}

llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilder<>* b) {
  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::BasicBlock* early_return =
      llvm::BasicBlock::Create(/*Context=*/module->getContext(),
                               /*Name=*/"early_return",
                               /*Parent=*/function);
  b->SetInsertPoint(early_return);
  b->CreateRetVoid();
  return early_return;
}

void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilder<>* b,
                     llvm::BasicBlock* return_block) {
  if (!return_block) {
    return_block = EmitReturnBlock(b);
  }

  llvm::BasicBlock* continued;

  // Implicitly check whether we are already at the end of an unterminated
  // block.
  if (b->GetInsertBlock()->getTerminator() == nullptr) {
    // If we are generating code into an incomplete basic block we can just
    // create a new basic block to jump to after our conditional branch.
    continued = llvm_ir::CreateBasicBlock(/*insert_before=*/nullptr,
                                          /*name=*/"", b);
  } else {
    // If we are generating code into a basic block that already has code, we
    // need to split that block so as to not disturb the existing code.
    auto original = b->GetInsertBlock();
    continued = original->splitBasicBlock(b->GetInsertPoint());
    // Remove the auto-generated unconditional branch to replace with our
    // conditional branch.
    original->getTerminator()->eraseFromParent();
    b->SetInsertPoint(original);
  }

  b->CreateCondBr(condition, continued, return_block);
  b->SetInsertPoint(continued, continued->getFirstInsertionPt());
}

}  // namespace llvm_ir
}  // namespace xla