/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful
// for LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include <functional>
#include <string>
#include <variant>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU-specific information.
using absl::StrCat;

// Wrapper structure carrying the LLVM intrinsic IDs for the NVPTX and AMDGPU
// platforms. On AMDGPU, some of these operations are implemented as device
// functions rather than intrinsics, so a variant holds either the intrinsic
// ID or a lambda that emits the corresponding device-function call (see the
// kBlockDim* cases below).
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  std::variant<llvm::Intrinsic::ID,
               std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};

// Gets the LLVM intrinsic IDs on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
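    // AMDGPU has no intrinsic for reading the block (workgroup) dimensions,
    // so the cases below emit a call to the ROCm device-library function
    // __ockl_get_local_size, which takes the dimension index (0 = x, 1 = y,
    // 2 = z) as a u32 and returns that dimension's size as a u64.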
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64, {},
                                              b_);
              }};
    }
  }
}

// Wrapper structure for carrying math function roots for the NVPTX and AMDGPU
// platforms.
struct TargetDeviceFunction {
  const std::string nvptx_root;
  const std::string amdgpu_root;
};

// Gets the device function name on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kTanh: {
      return {"__nv_tanh", "__ocml_tanh"};
    }
  }
}
}  // namespace

std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                     PrimitiveType output_type,
                                     llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double- or float-specific suffix to a root name. Both the
  // suffix and the root name are target-specific.
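  // For example, kCos with an F32 output yields "__nv_cosf" on NVPTX and
  // "__ocml_cos_f32" on AMDGPU.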
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

llvm::CallInst* EmitDeviceFunctionCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b, absl::string_view name) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                       // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

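  // Note: the call name reaches LLVM as a Twine built from the raw pointer,
  // so `name` is assumed to refer to a null-terminated string here.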
  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        std::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
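      // The variant holds a device-function emitter rather than an intrinsic
      // ID; invoke it to emit the call directly.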
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          std::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to llvm.nvvm.annotations so that the
    // NVPTX backend treats the function as a CUDA kernel.
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so the AMDGPU backend recognizes the function as an
    // AMDGPU kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
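    // Declare the range of flat workgroup sizes (1 to 1024 threads) that the
    // generated code is required to support.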
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

}  // namespace gpu
}  // namespace xla