xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/Triple.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/GlobalValue.h"
30 #include "llvm/IR/GlobalVariable.h"
31 #include "llvm/IR/MDBuilder.h"
32 #include "llvm/IR/Operator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Target/TargetOptions.h"
35 #include "llvm/Transforms/Utils/Cloning.h"
36 #include "tensorflow/compiler/xla/debug_options_flags.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
40 #include "tensorflow/compiler/xla/service/dump.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
42 #include "tensorflow/compiler/xla/service/name_uniquer.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/platform/env.h"
48 #include "tensorflow/core/platform/logging.h"
49 
50 namespace xla {
51 namespace llvm_ir {
52 
53 namespace {
54 
55 // Note, this function is only useful in an insertion context; in a global
56 // (e.g. constants) context it will CHECK fail.
ModuleFromIRBuilder(llvm::IRBuilder<> * b)57 llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
58   auto block = CHECK_NOTNULL(b->GetInsertBlock());
59   auto fn = CHECK_NOTNULL(block->getParent());
60   auto module = CHECK_NOTNULL(fn->getParent());
61   return module;
62 }
63 
64 }  // namespace
65 
DropConstantInitializers(const llvm::Module & module)66 std::unique_ptr<llvm::Module> DropConstantInitializers(
67     const llvm::Module& module) {
68   std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
69   for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
70     global_var.setInitializer(nullptr);
71     global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
72   }
73   return cloned_module;
74 }
75 
DumpModuleToString(const llvm::Module & module)76 std::string DumpModuleToString(const llvm::Module& module) {
77   std::string buffer_string;
78   llvm::raw_string_ostream ostream(buffer_string);
79   module.print(ostream, nullptr);
80   ostream.flush();
81   return buffer_string;
82 }
83 
EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,absl::Span<llvm::Value * const> operands,absl::Span<llvm::Type * const> overloaded_types,llvm::IRBuilder<> * b,absl::string_view name)84 llvm::CallInst* EmitCallToIntrinsic(
85     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
86     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
87     absl::string_view name) {
88   llvm::Module* module = ModuleFromIRBuilder(b);
89   llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
90       module, intrinsic_id, AsArrayRef(overloaded_types));
91   return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
92 }
93 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,bool enable_fast_min_max,absl::string_view name)94 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
95                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
96                           absl::string_view name) {
97   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
98     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
99     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
100   } else {
101     auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
102     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
103     auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
104     return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
105   }
106 }
107 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,bool enable_fast_min_max,absl::string_view name)108 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
109                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
110                           absl::string_view name) {
111   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
112     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
113     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
114   } else {
115     auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
116     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
117     auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
118     return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
119   }
120 }
121 
EmitBufferIndexingGEP(llvm::Value * array,llvm::Type * element_type,llvm::Value * index,llvm::IRBuilder<> * b)122 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
123                                    llvm::Value* index, llvm::IRBuilder<>* b) {
124   llvm::Type* array_type = array->getType();
125   CHECK(array_type->isPointerTy());
126   llvm::PointerType* array_type_as_pointer =
127       llvm::cast<llvm::PointerType>(array_type);
128   CHECK(array_type_as_pointer->isOpaqueOrPointeeTypeMatches(element_type));
129   VLOG(2) << "EmitBufferIndexingGEP with type="
130           << llvm_ir::DumpToString(*array_type)
131           << " array=" << llvm_ir::DumpToString(*array)
132           << " index=" << llvm_ir::DumpToString(*index);
133 
134   return b->CreateInBoundsGEP(
135       element_type, array,
136       llvm::isa<llvm::GlobalVariable>(array)
137           ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
138           : index);
139 }
140 
EmitBufferIndexingGEP(llvm::Value * array,llvm::Type * element_type,int64_t index,llvm::IRBuilder<> * b)141 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
142                                    int64_t index, llvm::IRBuilder<>* b) {
143   return EmitBufferIndexingGEP(array, element_type, b->getInt64(index), b);
144 }
145 
// Maps an XLA primitive type to the LLVM type used for its *storage*.  Note
// that several XLA types share one LLVM type (e.g. signed and unsigned
// integers of the same width, and PRED/S8/U8 all map to i8); the mapping is
// about size and layout, not arithmetic semantics.  LOG(FATAL)s on types with
// no storage mapping.
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a BF16
      // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
      // we can't map it directly to an LLVM type. We will not map a BF16
      // addition to an addition on this type (int16_t) - this is just the type
      // used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      // Reuse the named struct type if some earlier call already created it in
      // this LLVM context; struct names are context-global.
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex64");
      if (cplx_t == nullptr) {
        // C++ standard dictates the memory layout of std::complex is contiguous
        // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      // Same scheme as C64, with double-precision components.
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }  // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8_t*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}
216 
GetSizeInBits(llvm::Type * type)217 int GetSizeInBits(llvm::Type* type) {
218   const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
219   if (struct_ty) {
220     CHECK(struct_ty->isPacked());
221     int bits = 0;
222     for (auto element_type : struct_ty->elements()) {
223       bits += GetSizeInBits(element_type);
224     }
225     return bits;
226   }
227   int bits = type->getPrimitiveSizeInBits();
228   CHECK_GT(bits, 0) << "type is not sized";
229   return bits;
230 }
231 
ShapeToIrType(const Shape & shape,llvm::Module * module)232 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
233   llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
234   if (shape.IsTuple()) {
235     // A tuple buffer is an array of pointers.
236     result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
237   } else if (shape.IsArray()) {
238     for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
239       result_type =
240           llvm::ArrayType::get(result_type, shape.dimensions(dimension));
241     }
242   }
243   return result_type;
244 }
245 
EncodeSelfDescribingShapeConstant(const Shape & shape,int32_t * shape_size,llvm::IRBuilder<> * b)246 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
247                                                          int32_t* shape_size,
248                                                          llvm::IRBuilder<>* b) {
249   std::string encoded_shape = shape.SerializeAsString();
250   if (encoded_shape.size() > std::numeric_limits<int32_t>::max()) {
251     return InternalError("Encoded shape size exceeded int32_t size limit.");
252   }
253   *shape_size = static_cast<int32_t>(encoded_shape.size());
254   return b->CreateGlobalStringPtr(encoded_shape);
255 }
256 
ConvertLiteralToIrConstant(const Literal & literal,llvm::Module * module)257 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
258                                            llvm::Module* module) {
259   const char* data = static_cast<const char*>(literal.untyped_data());
260   CHECK_EQ(module->getDataLayout().isLittleEndian(),
261            tensorflow::port::kLittleEndian);
262   return llvm::ConstantDataArray::getString(
263       module->getContext(), llvm::StringRef(data, literal.size_bytes()),
264       /*AddNull=*/false);
265 }
266 
AllocateSharedMemoryTile(llvm::Module * module,llvm::Type * tile_type,absl::string_view name)267 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
268                                                llvm::Type* tile_type,
269                                                absl::string_view name) {
270   // Both AMDGPU and NVPTX use the same address space for shared memory.
271   const int kGPUSharedMemoryAddrSpace = 3;
272   return new llvm::GlobalVariable(
273       *module, tile_type,
274       /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
275       llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
276       llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
277 }
278 
EmitAllocaAtFunctionEntry(llvm::Type * type,absl::string_view name,llvm::IRBuilder<> * b,int alignment)279 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
280                                             absl::string_view name,
281                                             llvm::IRBuilder<>* b,
282                                             int alignment) {
283   return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
284 }
285 
EmitAllocaAtFunctionEntryWithCount(llvm::Type * type,llvm::Value * element_count,absl::string_view name,llvm::IRBuilder<> * b,int alignment)286 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
287                                                      llvm::Value* element_count,
288                                                      absl::string_view name,
289                                                      llvm::IRBuilder<>* b,
290                                                      int alignment) {
291   llvm::IRBuilder<>::InsertPointGuard guard(*b);
292   llvm::Function* function = b->GetInsertBlock()->getParent();
293   b->SetInsertPoint(&function->getEntryBlock(),
294                     function->getEntryBlock().getFirstInsertionPt());
295   llvm::AllocaInst* alloca =
296       b->CreateAlloca(type, element_count, AsStringRef(name));
297   if (alignment != 0) {
298     alloca->setAlignment(llvm::Align(alignment));
299   }
300   return alloca;
301 }
302 
CreateBasicBlock(llvm::BasicBlock * insert_before,absl::string_view name,llvm::IRBuilder<> * b)303 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
304                                    absl::string_view name,
305                                    llvm::IRBuilder<>* b) {
306   return llvm::BasicBlock::Create(
307       /*Context=*/b->getContext(),
308       /*Name=*/AsStringRef(name),
309       /*Parent=*/b->GetInsertBlock()->getParent(),
310       /*InsertBefore*/ insert_before);
311 }
312 
// Emits an if/then(/else) diamond at the current insertion point and returns
// the blocks involved.  The current block becomes the `if` block ending in a
// conditional branch; `<name>-true` (and optionally `<name>-false`) branch to
// `<name>-after`, where the builder is left positioned.  Any instructions
// that followed the original insertion point end up in the after block.
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    // No terminator yet: create a fresh after block and branch to it.
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    // Already terminated: split at the insertion point so everything from
    // there on (including the old terminator) moves into the after block.
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  // if_block: branch on the condition.  Without an else, false falls through
  // directly to the after block.
  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  // Both arms rejoin at the after block.
  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  // Leave the builder ready to emit code that runs after the conditional.
  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}
355 
EmitComparison(llvm::CmpInst::Predicate predicate,llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,absl::string_view name)356 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
357                             llvm::Value* lhs_value, llvm::Value* rhs_value,
358                             llvm::IRBuilder<>* b, absl::string_view name) {
359   llvm::Value* comparison_result;
360   if (lhs_value->getType()->isIntegerTy()) {
361     comparison_result =
362         b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
363   } else {
364     comparison_result =
365         b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
366   }
367   // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
368   // arrays. So we extend it to i8 so that it's addressable.
369   return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
370                                               PRED, ModuleFromIRBuilder(b)));
371 }
372 
373 // Internal helper that is called from emitted code to log an int64_t value with
374 // a tag.
LogS64(const char * tag,int64_t value)375 static void LogS64(const char* tag, int64_t value) {
376   LOG(INFO) << tag << " (int64_t): " << value;
377 }
378 
// Emits IR that calls LogS64(tag, value) at runtime.  The host addresses of
// LogS64 and `tag` are baked into the IR as i64 constants, so this
// presumably only works when the generated code runs in this same process
// (JIT) — verify before reusing elsewhere.
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  // void(i64, i64): the tag pointer is smuggled through as an i64.
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  // Materialize &LogS64 as an i64 constant and cast it back to a function
  // pointer for the call.
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64_t>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64_t>(tag)), value});
}
387 
SetAlignmentMetadataForLoad(llvm::LoadInst * load,uint64_t alignment)388 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
389   llvm::LLVMContext& context = load->getContext();
390   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
391   llvm::Constant* alignment_constant =
392       llvm::ConstantInt::get(int64_ty, alignment);
393   llvm::MDBuilder metadata_builder(context);
394   auto* alignment_metadata =
395       metadata_builder.createConstant(alignment_constant);
396   load->setMetadata(llvm::LLVMContext::MD_align,
397                     llvm::MDNode::get(context, alignment_metadata));
398 }
399 
SetDereferenceableMetadataForLoad(llvm::LoadInst * load,uint64_t dereferenceable_bytes)400 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
401                                        uint64_t dereferenceable_bytes) {
402   llvm::LLVMContext& context = load->getContext();
403   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
404   llvm::Constant* dereferenceable_bytes_constant =
405       llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
406   llvm::MDBuilder metadata_builder(context);
407   auto* dereferenceable_bytes_metadata =
408       metadata_builder.createConstant(dereferenceable_bytes_constant);
409   load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
410                     llvm::MDNode::get(context, dereferenceable_bytes_metadata));
411 }
412 
AddRangeMetadata(int32_t lower,int32_t upper,llvm::Instruction * inst)413 llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper,
414                                     llvm::Instruction* inst) {
415   llvm::LLVMContext& context = inst->getParent()->getContext();
416   llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
417   inst->setMetadata(
418       llvm::LLVMContext::MD_range,
419       llvm::MDNode::get(
420           context,
421           {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
422            llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
423   return inst;
424 }
425 
IrName(absl::string_view a)426 std::string IrName(absl::string_view a) {
427   std::string s(a);
428   s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
429   return s;
430 }
431 
IrName(absl::string_view a,absl::string_view b)432 std::string IrName(absl::string_view a, absl::string_view b) {
433   if (!a.empty() && !b.empty()) {
434     return IrName(absl::StrCat(a, ".", b));
435   }
436   return IrName(absl::StrCat(a, b));
437 }
438 
IrName(const HloInstruction * a,absl::string_view b)439 std::string IrName(const HloInstruction* a, absl::string_view b) {
440   return IrName(a->name(), b);
441 }
442 
std::string SanitizeFunctionName(std::string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Replace every character outside [a-zA-Z0-9_$] with '_'.
  auto is_legal_char = [](char c) {
    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
           ('0' <= c && c <= '9') || c == '_' || c == '$';
  };
  for (char& c : function_name) {
    if (!is_legal_char(c)) {
      c = '_';
    }
  }

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number (it is guaranteed non-empty
  // at this point).
  if (function_name[0] >= '0' && function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}
479 
SetToFirstInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)480 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
481   builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
482 }
483 
SetToLastInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)484 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
485   if (llvm::Instruction* terminator = blk->getTerminator()) {
486     builder->SetInsertPoint(terminator);
487   } else {
488     builder->SetInsertPoint(blk);
489   }
490 }
491 
CreateRor(llvm::Value * rotand,llvm::Value * rotor,llvm::IRBuilder<> * builder)492 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
493                        llvm::IRBuilder<>* builder) {
494   auto size = rotand->getType()->getPrimitiveSizeInBits();
495   auto size_value = builder->getIntN(size, size);
496   auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
497   return builder->CreateOr(
498       builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
499       builder->CreateLShr(rotand, mod(rotor)));
500 }
501 
ByteSizeOf(const Shape & shape,const llvm::DataLayout & data_layout)502 int64_t ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
503   unsigned pointer_size = data_layout.getPointerSize();
504   return ShapeUtil::ByteSizeOf(shape, pointer_size);
505 }
506 
GetCpuFastMathFlags(const HloModuleConfig & module_config)507 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
508   llvm::FastMathFlags flags;
509   const auto& options = module_config.debug_options();
510   if (!options.xla_cpu_enable_fast_math()) {
511     return flags;
512   }
513   // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
514   // AllowContract, and ApproxFunc.
515   flags.setFast();
516   flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
517   flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
518   flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
519   flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
520   return flags;
521 }
522 
// Merges two metadata maps (keyed by metadata kind) so that the result is
// valid for code combining both origins: !alias.scope lists are unioned,
// !noalias lists are intersected.  Kinds present only in `b`, and other
// kinds in `a`, are dropped.
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      // Union of scopes: all of a's scopes, plus any of b's not already seen.
      // scope_set deduplicates; union_of_scopes preserves insertion order.
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      // Intersection of scopes: only scopes present in both a and b survive
      // (a noalias guarantee must hold for both origins to be kept).
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      // An empty intersection means no common guarantee; emit no !noalias.
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}
570 
CreateAndWriteStringToFile(const std::string & directory_name,const std::string & file_name,const std::string & text)571 static Status CreateAndWriteStringToFile(const std::string& directory_name,
572                                          const std::string& file_name,
573                                          const std::string& text) {
574   std::unique_ptr<tensorflow::WritableFile> f;
575   TF_RETURN_IF_ERROR(
576       tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
577   TF_RETURN_IF_ERROR(
578       tensorflow::Env::Default()->NewWritableFile(file_name, &f));
579   TF_RETURN_IF_ERROR(f->Append(text));
580   TF_RETURN_IF_ERROR(f->Close());
581   return OkStatus();
582 }
583 
// Dumps the textual IR of `llvm_module` for `hlo_module` when dumping is
// enabled, plus (when dumping to files rather than stdout) a second copy
// with constant initializers stripped.  `optimized` and `filename_suffix`
// only affect the dump file names.
void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized,
                     absl::string_view filename_suffix) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  std::string suffix =
      absl::StrCat("ir-", optimized ? "with" : "no", "-opt",
                   filename_suffix.empty() ? "" : ".", filename_suffix);
  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the module
  // with the constants stripped to get IR that is easier to manipulate.  Skip
  // this if we're dumping to stdout; there's no point in duplicating everything
  // when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}
609 
CreateCpuFunction(llvm::FunctionType * function_type,llvm::GlobalValue::LinkageTypes linkage,const HloModuleConfig & module_config,absl::string_view name,llvm::Module * module)610 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
611                                   llvm::GlobalValue::LinkageTypes linkage,
612                                   const HloModuleConfig& module_config,
613                                   absl::string_view name,
614                                   llvm::Module* module) {
615   llvm::Function* function =
616       llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
617   function->setCallingConv(llvm::CallingConv::C);
618   function->addFnAttr("no-frame-pointer-elim", "false");
619 
620   // Generate unwind information so that GDB can crawl through the stack frames
621   // created by the JIT compiled code.
622   function->setUWTableKind(llvm::UWTableKind::Default);
623 
624   // Tensorflow always flushes denormals to zero, let LLVM know that flushing
625   // denormals is safe. This allows vectorization using ARM's neon instruction
626   // set.
627   function->addFnAttr("denormal-fp-math", "preserve-sign");
628 
629   // Add the optimize attribute to the function if optimizing for size. This
630   // controls internal behavior of some optimization passes (e.g. loop
631   // unrolling).
632   if (cpu::options::OptimizeForSizeRequested(module_config)) {
633     function->addFnAttr(llvm::Attribute::OptimizeForSize);
634   }
635 
636   return function;
637 }
638 
UMulLowHigh32(llvm::IRBuilder<> * b,llvm::Value * src0,llvm::Value * src1)639 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
640                                                     llvm::Value* src0,
641                                                     llvm::Value* src1) {
642   CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
643   CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
644   llvm::Type* int64_ty = b->getInt64Ty();
645   src0 = b->CreateZExt(src0, int64_ty);
646   src1 = b->CreateZExt(src1, int64_ty);
647   return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
648 }
649 
SplitInt64ToInt32s(llvm::IRBuilder<> * b,llvm::Value * value_64bits)650 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
651     llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
652   CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
653   llvm::Type* int32_ty = b->getInt32Ty();
654   llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
655   llvm::Value* high_32bits =
656       b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
657   return std::make_pair(low_32bits, high_32bits);
658 }
659 
// Returns the LLVM address space used for globals emitted by this file (see
// GetOrCreateVariableForRngState). NOTE(review): 1 is a non-default address
// space — presumably the GPU backend's global-memory space; confirm against
// the target's data layout before reusing elsewhere.
unsigned GetGlobalMemoryAddressSpace() { return 1; }
661 
GetOrCreateVariableForRngState(llvm::Module * module,llvm::IRBuilder<> * b)662 llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
663                                                      llvm::IRBuilder<>* b) {
664   static const char* kRngStateVariableName = "rng_state";
665   llvm::GlobalVariable* state_ptr =
666       module->getNamedGlobal(kRngStateVariableName);
667   if (!state_ptr) {
668     llvm::Type* state_type = b->getInt128Ty();
669     // Use a non-zero initial value as zero state can cause the result of the
670     // first random number generation not passing the chi-square test. The
671     // values used here are arbitrarily chosen, any non-zero values should be
672     // fine.
673     state_ptr = new llvm::GlobalVariable(
674         /*M=*/*module,
675         /*Ty=*/state_type,
676         /*isConstant=*/false,
677         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
678         /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
679         /*Name=*/kRngStateVariableName,
680         /*InsertBefore=*/nullptr,
681         /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
682         /*AddressSpace=*/GetGlobalMemoryAddressSpace(),
683         /*isExternallyInitialized=*/false);
684   }
685   return state_ptr;
686 }
687 
RngGetAndUpdateState(uint64_t delta,llvm::Module * module,llvm::IRBuilder<> * builder)688 llvm::Value* RngGetAndUpdateState(uint64_t delta, llvm::Module* module,
689                                   llvm::IRBuilder<>* builder) {
690   llvm::GlobalVariable* state_ptr =
691       GetOrCreateVariableForRngState(module, builder);
692   llvm::LoadInst* state_value_old =
693       builder->CreateLoad(state_ptr->getValueType(), state_ptr, "load_state");
694   llvm::Value* state_value_new = builder->CreateAdd(
695       state_value_old,
696       llvm::ConstantInt::get(state_value_old->getType(), delta));
697   builder->CreateStore(state_value_new, state_ptr);
698   return state_value_old;
699 }
700 
EmitReturnBlock(llvm::IRBuilder<> * b)701 llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilder<>* b) {
702   llvm::Function* function = b->GetInsertBlock()->getParent();
703   llvm::Module* module = b->GetInsertBlock()->getModule();
704   llvm::IRBuilder<>::InsertPointGuard guard(*b);
705   llvm::BasicBlock* early_return =
706       llvm::BasicBlock::Create(/*Context=*/module->getContext(),
707                                /*Name=*/"early_return",
708                                /*Parent=*/function);
709   b->SetInsertPoint(early_return);
710   b->CreateRetVoid();
711   return early_return;
712 }
713 
// Emits a conditional early exit at the current insert point: if `condition`
// is true, execution falls through to the code emitted next; if it is false,
// control transfers to `return_block` (created via EmitReturnBlock when the
// caller passes null). On return, the builder is positioned at the start of
// the fall-through block.
void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilder<>* b,
                     llvm::BasicBlock* return_block) {
  if (!return_block) {
    return_block = EmitReturnBlock(b);
  }

  llvm::BasicBlock* continued;

  // Implicitly check whether we are already at the end of an unterminated
  // block.
  if (b->GetInsertBlock()->getTerminator() == nullptr) {
    // If we are generating code into an incomplete basic block we can just
    // create a new basic block to jump to after our conditional branch.
    continued = llvm_ir::CreateBasicBlock(/*insert_before=*/nullptr,
                                          /*name=*/"", b);
  } else {
    // If we are generating code into a basic block that already has code, we
    // need to split that block so as to not disturb the existing code.
    auto original = b->GetInsertBlock();
    continued = original->splitBasicBlock(b->GetInsertPoint());
    // splitBasicBlock leaves an auto-generated unconditional branch from
    // `original` to `continued`; remove it so we can replace it with our
    // conditional branch below.
    original->getTerminator()->eraseFromParent();
    b->SetInsertPoint(original);
  }

  b->CreateCondBr(condition, continued, return_block);
  b->SetInsertPoint(continued, continued->getFirstInsertionPt());
}
742 
743 }  // namespace llvm_ir
744 }  // namespace xla
745