/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <vector>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

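// `compute_nested` is the callback this emitter uses to generate calls into
// nested HLO computations; the constructor only stores it for later use.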
GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Device functions don't have f16 math functions, so we convert the
  // operands to f32 before calling the function and then convert the result
  // back to f16.
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64_t i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      [[fallthrough]];
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
  const std::string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}
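
// For example (a sketch, assuming an NVPTX target where
// ObtainDeviceFunctionName resolves TargetDeviceFunctionID::kExp to the
// libdevice routine __nv_expf), an F16 exp is emitted roughly as:
//   %x32 = fpext half %x to float
//   %y32 = call float @__nv_expf(float %x32)
//   %y   = fptrunc float %y32 to half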

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
  std::string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}
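
// For example, callee_name "llvm.sin" with output_type F32 becomes a call to
// "llvm.sin.f32".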

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input type must match the
  // output type.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}
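
// ReadNone and NoUnwind mark the callee as having no side effects and never
// unwinding, which lets LLVM freely CSE, hoist, and reorder these math calls.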

llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast(
    const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) {
  Shape shape = hlo->shape();
  Shape operand_shape = hlo->operand(0)->shape();

  // Decode the layouts of the result and source shapes from the protobuf
  // attached to backend_config_.
  BitcastBackendConfig bitcast_config;
  CHECK(bitcast_config.ParseFromString(hlo->raw_backend_config_string()));

  *shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.result_layout());
  *operand_shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.source_layout());
  return index.SourceIndexOfBitcast(shape, operand_shape, b());
}
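
// SourceIndexOfBitcast linearizes `index` in the result layout and
// delinearizes it in the source layout; e.g. (a sketch) for a bitcast from
// f32[6]{0} to f32[2,3]{1,0}, output index (i, j) maps back to source index
// (3 * i + j).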

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}
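
// llvm.maxnum/llvm.minnum return the non-NaN operand when exactly one input
// is NaN, while XLA's default min/max semantics propagate NaNs; that is why
// this lowering is gated behind xla_gpu_enable_fast_min_max.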

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is being requested, assume performance is less important and use
  // the more numerically precise tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);

  // If |value| >= kMaxValue, clamp the result to -1.0 or 1.0, matching the
  // sign of the input.
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                value->getType(), "tanh");
}
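
// For f32, tanh(x) already rounds to exactly +/-1.0 well before |x| reaches
// kMaxValue = 20.0, so the copysign clamp is exact there; note also that
// FCmpULT is an unordered comparison, so NaN inputs take the fast-tanh path.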

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}
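
// hypot(re, im) computes sqrt(re^2 + im^2) without intermediate overflow or
// underflow, which a naive sqrt(re * re + im * im) would suffer for large or
// tiny components.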

llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
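
// This is the standard CUDA linearization blockIdx.x * blockDim.x +
// threadIdx.x, widened to 128 bits so the nsw multiply-add cannot overflow.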

}  // namespace gpu
}  // namespace xla