/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <vector>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
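// Broadcasts of such a literal also match, so this fires both for a scalar
// constant and for a constant splatted to a larger shape.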
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // The device libraries don't provide f16 math functions, so we convert the
  // operands to f32 before calling the function and then convert the result
  // back to f16.
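  // For example, on the NVPTX backend an f16 exp lowers to roughly
  // fptrunc(call @__nv_expf(fpext ... to float)), since libdevice only ships
  // f32 (__nv_expf) and f64 (__nv_exp) variants.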
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64_t i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      [[fallthrough]];
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
  const std::string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
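  // For example, "llvm.sqrt" becomes "llvm.sqrt.f32" for an F32 call.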
  std::string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const std::string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input type must match the
  // output type.
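  // For example, pow is (f32, f32) -> f32; a mixed (f32, f64) call has no
  // single device function to dispatch to, so it is rejected here.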
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

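  // ReadNone/NoUnwind tell LLVM the callee has no side effects and cannot
  // throw, which lets it CSE and reorder these math calls freely.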
  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast(
    const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) {
  Shape shape = hlo->shape();
  Shape operand_shape = hlo->operand(0)->shape();

  // Decode the source and result layouts of the shapes from the protobuf
  // attached to backend_config_.
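  // A bitcast reinterprets the underlying buffer without moving data, so
  // mapping a result index back to a source index needs the physical layouts
  // of both shapes.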
  BitcastBackendConfig bitcast_config;
  CHECK(bitcast_config.ParseFromString(hlo->raw_backend_config_string()));

  *shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.result_layout());
  *operand_shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.source_layout());
  return index.SourceIndexOfBitcast(shape, operand_shape, b());
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

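  // With xla_gpu_enable_fast_min_max, lower min/max straight to
  // llvm.minnum/llvm.maxnum. These follow IEEE minNum/maxNum and return the
  // non-NaN operand when one input is NaN, so NaNs are not propagated the way
  // HLO's default min/max semantics require; that is the trade-off for speed.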
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

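// The wrappers below lower each elementwise HLO to the matching device
// library function through its TargetDeviceFunctionID; the target-specific
// name (e.g. __nv_sin on NVPTX vs. an __ocml_* function on AMDGPU) is
// resolved by ObtainDeviceFunctionName.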
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is being requested, assume performance is less important and use
  // the more numerically precise tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);
  // If |value| >= kMaxValue, the result saturates to +/-1.0, carrying the
  // sign of the input.
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

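  // EmitFastTanh emits a branch-free rational (polynomial-ratio)
  // approximation, so the surrounding code stays vectorizable.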
  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                value->getType(), "tanh");
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
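  // hypot(re, im) computes sqrt(re^2 + im^2) without the intermediate
  // overflow/underflow of the naive formula.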
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}

llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
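  // Linear thread id = blockIdx.x * blockDim.x + threadIdx.x. The arithmetic
  // is done in i128 so it cannot overflow for any launch size, and the NSW
  // flags let LLVM simplify it.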
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

}  // namespace gpu
}  // namespace xla