/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful for
// LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include <functional>
#include <string>
#include <variant>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU specific information.
using absl::StrCat;

// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU
// platforms. On AMDGPU, some of these operations are provided as device
// functions rather than intrinsics, so a variant holds either the intrinsic
// id or a lambda that emits the corresponding device-function call.
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  std::variant<llvm::Intrinsic::ID,
               std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};

// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
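    // The block-dimension queries below have no AMDGPU intrinsic equivalent;
    // they call the ROCm device-library function __ockl_get_local_size(dim),
    // which returns the workgroup size along the requested dimension as a
    // 64-bit value.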
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64, {},
                                              b_);
              }};
    }
  }
}

// Wrapper structure for carrying the root names of math functions for the
// NVPTX/AMDGPU platforms.
struct TargetDeviceFunction {
  const std::string nvptx_root;
  const std::string amdgpu_root;
};

// Gets the device function root names on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kTanh: {
      return {"__nv_tanh", "__ocml_tanh"};
    }
  }
}
}  // namespace

std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                     PrimitiveType output_type,
                                     llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double or float specific suffix to a root name. The suffix
  // and the root name are specific to the target.
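  // For example, TargetDeviceFunctionID::kCos with an F32 output type becomes
  // "__nv_cosf" on NVPTX and "__ocml_cos_f32" on AMDGPU.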
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

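// Illustrative use, with a hypothetical F32 operand `x` and the current
// IRBuilder `b`:
//   EmitDeviceFunctionCall("__nv_cosf", {x}, {F32}, F32, {}, b);
// This declares "__nv_cosf" in the module if it is not declared already and
// emits a call to it.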
llvm::CallInst* EmitDeviceFunctionCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b, absl::string_view name) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                        // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        std::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
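      // No intrinsic is registered for this ID on AMDGPU; emit the stored
      // device-function call instead.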
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          std::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to llvm.nvvm.annotations so that
    // NVPTX treats the function as a CUDA kernel.
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so the AMDGPU backend recognizes the function as an
    // AMDGPU kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
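    // Bound the flat workgroup size ("min, max") the backend may assume when
    // compiling this kernel.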
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

}  // namespace gpu
}  // namespace xla