/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
StatusOr<llvm::Value*> EmitReducePrecisionIR(
    PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
    int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) {
  using llvm::APInt;

  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }

  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64_t nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);

  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;
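  // For example, for F32 inputs these are nbits = 32, src_mantissa_bits = 23,
  // and src_exponent_bits = 8 (1 sign bit + 8 exponent bits + 23 mantissa
  // bits).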

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  // Clear the sign bit; it does not participate in rounding, and we will
  // restore it later.
  APInt sign_bit_mask(nbits, 1);
  sign_bit_mask <<= nbits - 1;
  llvm::Value* x_abs_bits =
      b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, ~sign_bit_mask));

  APInt exp_bits_mask(nbits, 1);
  exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                  << src_mantissa_bits;
  auto x_is_nan = b->CreateICmpUGT(
      x_abs_bits, llvm::ConstantInt::get(int_type, exp_bits_mask));

  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));
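    // For example, dropping 16 of F32's 23 mantissa bits (as in the BF16
    // conversion below) gives last_mantissa_bit_mask = 0x00010000 and
    // base_rounding_bias = 0x00007FFF; when the last kept mantissa bit is 1,
    // the bias becomes 0x00008000, so exact ties round up to an even result.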

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    llvm::Value* x_rounded = b->CreateAdd(x_as_int, x_rounding_bias);
    x_rounded = b->CreateAnd(x_rounded,
                             llvm::ConstantInt::get(int_type, truncation_mask));
    if (quiet_nans) {
      x_as_int = b->CreateSelect(x_is_nan, x_as_int, x_rounded);
    } else {
      x_as_int = x_rounded;
    }
  }

  if (dest_exponent_bits < src_exponent_bits) {
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponent_bits >= 1.
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;

    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;

    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
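    // For example, reducing F32 (src_exponent_bits = 8, bias 127) to a 5-bit
    // exponent gives reduced_exponent_bias = 15, so biased exponents above
    // 127 + 15 = 142 overflow to infinity and biased exponents at or below
    // 127 - 15 = 112 underflow to zero.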

    // Do we overflow or underflow?
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.

  if (dest_mantissa_bits > 0) {
    if (quiet_nans) {
      APInt qnan_mask(nbits, 1);
      qnan_mask <<= src_mantissa_bits - 1;
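      // The most-significant mantissa bit is the IEEE 754 quiet bit; setting
      // it turns any NaN into a quiet NaN while keeping it a NaN.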
      llvm::Value* x_with_qnan_bit_set =
          b->CreateOr(x_as_int, llvm::ConstantInt::get(int_type, qnan_mask));
      x_with_qnan_bit_set = b->CreateBitCast(x_with_qnan_bit_set, float_type);
      result = b->CreateSelect(x_is_nan, x_with_qnan_bit_set, result);
    } else {
      result = b->CreateSelect(x_is_nan, x, result);
    }
  } else {
    result = b->CreateSelect(x_is_nan,
                             llvm::ConstantFP::getInfinity(float_type), result);
  }

  return result;
}

StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
                                     llvm::IRBuilder<>* b) {
  TF_ASSIGN_OR_RETURN(
      auto reduced_precision,
      EmitReducePrecisionIR(
          /*src_ty=*/F32, f32_value,
          /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
          /*quiet_nans=*/true, b));
  auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
  auto shifted = b->CreateLShr(as_int32, 16);
  auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
  return b->CreateBitCast(truncated, b->getInt16Ty());
}

llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
  auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
  auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
  auto shifted = b->CreateShl(as_int32, 16);
  return b->CreateBitCast(shifted, b->getFloatTy());
}
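// Since BF16 is exactly the upper 16 bits of F32, the widening above is
// exact; for non-NaN values, EmitBF16ToF32 followed by EmitF32ToBF16 is the
// identity, because the low 16 mantissa bits are already zero and round away
// to nothing.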

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return b->CreateSIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return b->CreateUIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
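      // The arithmetic shift smears the sign bit across the word: ashr is 0
      // for non-negative inputs and -1 (all ones) for negative inputs, so
      // Or(ashr, 1) yields +1 or -1, and the select below maps zero to zero.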
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      auto* to_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module_);
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value, to_ir_type);
      }
      auto* from_ir_type = llvm_ir::PrimitiveTypeToIrType(from_type, module_);
      int to_width = primitive_util::BitWidth(to_type);
      if (primitive_util::IsSignedIntegralType(to_type)) {
        int64_t min_int = llvm::minIntN(to_width);
        int64_t max_int = llvm::maxIntN(to_width);
        auto zero_int = llvm::ConstantInt::get(to_ir_type, 0);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToSI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // x <= static_cast<float>(INT_MIN) ? INT_MIN : ...
        clamped = Select(FCmpOLE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(INT_MAX) ? INT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        // isnan(x) ? 0 : ...
        clamped =
            Select(FCmpUNO(operand_value, operand_value), zero_int, clamped);
        return clamped;
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        uint64_t min_int = 0;
        uint64_t max_int = llvm::maxUIntN(to_width);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToUI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // (x <= 0.0 || isnan(x)) ? 0 : ...
        clamped = Select(FCmpULE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(UINT_MAX) ? UINT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        return clamped;
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value, "");
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    // TODO(b/238238423): llvm::Intrinsic::nearbyint is equivalent to roundeven
    // as TF and JAX default to FE_TONEAREST. Call llvm::Intrinsic::roundeven
    // instead once the GPU emitter supports lowering it.
    case HloOpcode::kRoundNearestEven:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nearbyint,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf; this works because the ordered comparison returns
      // false if either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      return EmitComplexLog(op, operand_value);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
      // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle,
                          EmitAtan2(component_type, b, a_plus_one, ""));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(-2a))])
              / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
             =(e^(2a)-e^(-2a) +
               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
              ([cos(b)^2 + sin(b)^2]*[e^(2a)+e^(-2a)] + 2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
              (e^(2a)+e^(-2a)+2*[cos(2b)])
             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
      */
      llvm::Value* a = EmitExtractReal(operand_value);
      llvm::Value* b = EmitExtractImag(operand_value);

      llvm::Type* type = a->getType();

      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
      llvm::Value* two_a = FAdd(a, a);
      llvm::Value* neg_2a = FMul(neg_one, two_a);

      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
      // values of `a`, we will get a ULP of 2^-23 using the exp function.
      // Using expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1]
      // allows our ULP to be arbitrarily small. For larger values of `a`,
      // calculating the numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a)
      // returns virtually identical results.
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
                          EmitExpm1(component_type, two_a));
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
                          EmitExpm1(component_type, neg_2a));
      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);

      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
      // = 2cos(b)^2. This gives us the ability to be more precise when the
      // denominator is close to zero.
      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);

      // Similarly we can compute sin(2b) with the formula sin(2b) =
      // 2*sin(b)*cos(b).
      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));

      // Expm1(x) is about x for small values of x, but exp_sum_m2 =
      // e^(2a)+e^(-2a)-2 is about x^2 for small values of x. As a result, due
      // to floating-point precision issues, x^2 is a better approximation
      // than Expm1(2a) + Expm1(-2a) for small values of x.
      llvm::Value* a_sqr = FMul(a, a);
      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);

      llvm::Value* exp_sum_m2 =
          Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);

      // As `a` grows toward +inf and -inf, the real numerator will grow
      // towards +inf and -inf respectively, while the denominator will always
      // grow towards +inf. The result is real_numerator/denom = NaN, when it
      // should equal +1 and -1 respectively. Therefore, if our denominator is
      // +inf, we just hardcode the limits for the real numbers.
      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);

      llvm::Value* real =
          Select(is_inf, real_limit, FDiv(real_numerator, denom));
      llvm::Value* imag = FDiv(imag_numerator, denom);

      // The complex tanh functions have a few corner cases:
      // 1. (+0, +0) => (+0, +0) - Handled normally
      // 2. (x, +Inf) => (NaN, NaN) - See below
      // 3. (x, NaN) => (NaN, NaN) - See below
      // 4. (+inf, y) => (1, +0) - Handled normally
      // 5. (+Inf, +Inf) => (1, +/-0) - See below
      // 6. (+Inf, NaN) => (1, +/-0) - See below
      // 7. (NaN, +0) => (NaN, +0) - See below
      // 8. (NaN, y) => (NaN, NaN) - Handled normally
      // 9. (NaN, NaN) => (NaN, NaN) - Handled normally
      //
      // For the cases that aren't handled normally:
      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
      //      then we return (+/-1, +/-0). However, this is only true if we
      //      assume that a is infinity or b is finite. In the event that both
      //      a is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to
      //      (NaN, NaN).
      // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
      //      When the denominator is infinity, this assures us that the zero
      //      is the correct sign. However if our imaginary input results in
      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
      // 7)   In the event that a is NaN, the denominator will be NaN.
      //      Therefore, the normal calculation gives (NaN, NaN) while we need
      //      (NaN, +0).
      if (!(b_->getFastMathFlags().noNaNs() &&
            b_->getFastMathFlags().noInfs())) {
        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                          {a}, {type}, b_);
        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
        llvm::Value* nan = llvm::ConstantFP::getNaN(type);

        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
        llvm::Value* b_is_zero = FCmpOEQ(b, zero);

        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
        // imag_numerator is NaN.
        llvm::Value* sin_2b_is_nan =
            b_->CreateFCmpUNO(imag_numerator, imag_numerator);

        llvm::Value* real_is_nan =
            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
        llvm::Value* imag_is_zero =
            b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));

        real = Select(real_is_nan, nan, real);
        imag = Select(imag_is_zero, zero, imag);
      }

      return EmitComposeComplex(op, real, imag);
    }
    case HloOpcode::kAbs: {
      return EmitComplexAbs(component_type, operand_value);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      TF_ASSIGN_OR_RETURN(auto cplx_abs,
                          EmitComplexAbs(component_type, operand_value));
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      return EmitComplexSqrt(op, component_type, operand_value);
    }
    case HloOpcode::kRsqrt: {
      return EmitComplexRsqrt(op, component_type, operand_value);
    }
    case HloOpcode::kCbrt: {
      return EmitComplexCbrt(op, component_type, operand_value);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (operand_type == PRED) {
    return EmitPredBinaryOp(op, lhs_value, rhs_value);
  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape())) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value, op->name());
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value, op->name());
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value, op->name());
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value, op->name());
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value, op->name());
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
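    // For example, with x = NaN: FCMP_OEQ(x, y) is false for every y, while
    // FCMP_UNE(x, y) is true for every y, so kNe stays the exact negation of
    // kEq even in the presence of NaNs.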
    case HloOpcode::kCompare: {
      PrimitiveType operand_type = op->operand(0)->shape().element_type();
      if (operand_type == BF16) {
        lhs_value = EmitBF16ToF32(lhs_value, b_);
        rhs_value = EmitBF16ToF32(rhs_value, b_);
      }
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_, op->name());
      }
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value, op->name());
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value, op->name());
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
                     op->name());
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
                       op->name());
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

// Computing sqrt(a^2 + b^2) directly can overflow. Assuming |a| >= |b|, we
// can instead use
//   sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
//                   = |a| * sqrt(1 + (b/a)^2)
//
// This method returns the min, max, and sqrt term for this calculation. This
// is done to prevent potential overflow errors that can occur from
// multiplying the max with the sqrt term. (i.e. when calculating the sqrt of
// the absolute value, we can take the sqrt of the max and the sqrt term
// before multiplying them together.) If return_sqrt is false, it returns
// 1 + (b/a)^2 instead of sqrt(1 + (b/a)^2).
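//
// For example, for F32 inputs a = 3e30 and b = 4e30, a^2 and b^2 both
// overflow to +inf, but max * sqrt(1 + (min/max)^2) = 3e30 * sqrt(1 + (4/3)^2)
// = 5e30 stays finite.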
StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
                                         llvm::Value* operand_value,
                                         bool return_sqrt) {
  llvm::Value* real = EmitExtractReal(operand_value);
  llvm::Value* imag = EmitExtractImag(operand_value);
  llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
  llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
  llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
  llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");

  llvm::Value* div = FDiv(min, max);
  llvm::Value* div_sq = FMul(div, div);
  llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
  return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  llvm::Value* result = FMul(max, sqrt);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
//   sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* one_p_div_sq;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, one_p_div_sq),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * pow,
                      EmitPow(prim_type, one_p_div_sq,
                              llvm::ConstantFP::get(max->getType(), .25), ""));
  llvm::Value* result = FMul(sqrt_max, pow);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
//   rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
  llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return rsqrt(min) instead of `result`.
  return Select(FCmpUNO(result, result), rsqrt_min, result);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAdd(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSubtract(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexMultiply(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op,
      FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
      FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexDivide(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Division of complex numbers is implemented here, taking into account
  // over/underflow, NaN and Inf values.
  auto a_r = EmitExtractReal(lhs_value);
  auto a_i = EmitExtractImag(lhs_value);
  auto b_r = EmitExtractReal(rhs_value);
  auto b_i = EmitExtractImag(rhs_value);
  auto type = a_r->getType();

  // Smith's algorithm to divide complex numbers. It is just a slightly
  // smarter way to compute the following formula:
  //   (a_r + a_i * i) / (b_r + b_i * i)
  //     = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
  //     = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
  //
  // Depending on whether |b_r| < |b_i| we compute either
  //   b_r_b_i_ratio = b_r / b_i
  //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
  //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
  //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
  //
  // or
  //
  //   b_i_b_r_ratio = b_i / b_r
  //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
  //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
  //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
  //
  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
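  // The ratio form avoids squaring the denominator components: e.g., for F64
  // inputs with b_r = b_i = 1e200, ||b||^2 would overflow to +inf, while
  // b_r + b_i * (b_i / b_r) = 2e200 stays finite.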
  auto b_r_b_i_ratio = FDiv(b_r, b_i);
  auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
  auto b_i_b_r_ratio = FDiv(b_i, b_r);
  auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));

  auto b_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_r}, {type}, b_);
  auto b_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_i}, {type}, b_);
  auto b_r_lt_b_i = FCmpOLT(b_r_abs, b_i_abs);
  auto c_r = Select(b_r_lt_b_i,
                    FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
                    FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
  auto c_i = Select(b_r_lt_b_i,
                    FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
                    FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
  auto result = EmitComposeComplex(op, c_r, c_i);

  // Consider corner cases where the result is (NaN, NaN).
  auto zero = llvm::ConstantFP::get(type, 0.0);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto inf = llvm::ConstantFP::getInfinity(type);

  // Case 1. Zero denominator.
  auto zero_denominator =
      And(And(FCmpOEQ(b_r_abs, zero), FCmpOEQ(b_i_abs, zero)),
          Or(Not(FCmpUNO(a_r, zero)), Not(FCmpUNO(a_i, zero))));
  auto inf_with_sign_of_b_r = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {inf, b_r}, {type}, b_);
  auto zero_denominator_result = EmitComposeComplex(
      op, FMul(inf_with_sign_of_b_r, a_r), FMul(inf_with_sign_of_b_r, a_i));

  // Case 2. Infinite numerator, finite denominator.
  auto b_r_finite = FCmpONE(b_r_abs, inf);
  auto b_i_finite = FCmpONE(b_i_abs, inf);
  auto a_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_r}, {type}, b_);
  auto a_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_i}, {type}, b_);
  auto a_r_infinite = FCmpOEQ(a_r_abs, inf);
  auto a_i_infinite = FCmpOEQ(a_i_abs, inf);
  auto inf_num_finite_denom =
      And(Or(a_r_infinite, a_i_infinite), And(b_r_finite, b_i_finite));

  auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_r_infinite, one, zero), a_r}, {type},
      b_);
  auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_i_infinite, one, zero), a_i}, {type},
      b_);
  auto inf_num_finite_denom_result = EmitComposeComplex(
      op,
      FMul(inf,
           FAdd(FMul(a_r_inf_with_sign, b_r), FMul(a_i_inf_with_sign, b_i))),
      FMul(inf,
           FSub(FMul(a_i_inf_with_sign, b_r), FMul(a_r_inf_with_sign, b_i))));

  // Case 3. Finite numerator, infinite denominator.
  auto a_r_finite = FCmpONE(a_r_abs, inf);
  auto a_i_finite = FCmpONE(a_i_abs, inf);
  auto b_r_infinite = FCmpOEQ(b_r_abs, inf);
  auto b_i_infinite = FCmpOEQ(b_i_abs, inf);
  auto finite_num_inf_denom =
      And(Or(b_r_infinite, b_i_infinite), And(a_r_finite, a_i_finite));

  auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_r_infinite, one, zero), b_r}, {type},
      b_);
  auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_i_infinite, one, zero), b_i}, {type},
      b_);
  auto finite_num_inf_denom_result = EmitComposeComplex(
      op,
      FMul(zero,
           FAdd(FMul(a_r, b_r_inf_with_sign), FMul(a_i, b_i_inf_with_sign))),
      FMul(zero,
           FSub(FMul(a_i, b_r_inf_with_sign), FMul(a_r, b_i_inf_with_sign))));

  auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
  return Select(c_nan,
                Select(zero_denominator, zero_denominator_result,
                       Select(inf_num_finite_denom, inf_num_finite_denom_result,
                              Select(finite_num_inf_denom,
                                     finite_num_inf_denom_result, result))),
                result);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexLog(
    const HloInstruction* op, llvm::Value* operand_value) {
  // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto a = EmitExtractReal(operand_value);
  auto b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, ""));
  TF_ASSIGN_OR_RETURN(llvm::Value * abs,
                      EmitComplexAbs(component_type, operand_value));
  TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
  return EmitComposeComplex(op, log_abs, angle);
}

// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
//   = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
//   = r^0.5 * [cos(t/2) + i*sin(t/2)]
//   = sqrt(r) * [cos(t/2) + i*sin(t/2)]
// where r = |a+bi| and t = atan2(b,a)
// TODO(bixia): See doc for implementation without atan2.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitSqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part;
  llvm::Value* imag_part;

  llvm::Value* zero = llvm::ConstantFP::get(type, 0);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
                              zero, FMul(r, cos)));

    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
    imag_part =
        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
               Select(FCmpUNO(r, r), nan,
                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
  } else {
    real_part = FMul(r, cos);
    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
  }

  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
                EmitComposeComplex(op, real_part, imag_part));
}
1230
1231 // Similar to Sqrt, we can use our EmitComplexPower formula, but set
1232 // c=-0.5 and d=0. We get:
1233 // e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
1234 // = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
1235 // = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
1236 // = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
1237 // where r = |a+bi| and t = atan2(b,a).
EmitComplexRsqrt(const HloInstruction * op,PrimitiveType prim_type,llvm::Value * operand_value)1238 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
1239 const HloInstruction* op, PrimitiveType prim_type,
1240 llvm::Value* operand_value) {
1241 llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
1242 ->getElementType(0);
1243
1244 TF_ASSIGN_OR_RETURN(llvm::Value * r,
1245 EmitRsqrtComplexAbs(prim_type, operand_value));
1246
1247 llvm::Value* a = EmitExtractReal(operand_value);
1248 llvm::Value* b = EmitExtractImag(operand_value);
1249 TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
1250
1251 llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
1252 llvm::Value* angle = FMul(t, c);
1253 TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
1254 TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
1255
1256 llvm::Value* real_part = FMul(r, cos);
1257 llvm::Value* imag_part = FMul(r, sin);
1258
1259 if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
1260 llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1261 llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
1262 llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1263 llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1264 // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
    llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);

    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {a}, {a->getType()}, b_);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(
        is_zero_zero, inf,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               a_signed_zero, FMul(r, cos)));
    imag_part = Select(
        is_zero_zero, nan,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               neg_b_signed_zero, FMul(r, sin)));
  } else {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(is_zero_zero, inf, FMul(r, cos));
    imag_part = Select(is_zero_zero, nan, FMul(r, sin));
  }

  return EmitComposeComplex(op, real_part, imag_part);
}
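
// Worked example (sanity check): for z = -1 + 0i the code above computes
// r = rsqrt(|z|) = 1 and t = atan2(0, -1) = pi, so angle = -pi/2, cos = 0,
// sin = -1, and the result is 0 - 1i. This matches the principal branch:
// 1/sqrt(-1) = 1/i = -i.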

// Computes the cube root of z = a+bi by calling EmitComplexPower with
// c = 1.0/3.0 and d = 0.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
  auto zero = llvm::ConstantFP::get(type, 0);
  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  return EmitComplexPower(op, a, b, third, zero);
}
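
// Note that this computes the principal complex cube root, which differs from
// the real-valued cbrt extended to negative inputs: for z = -8 + 0i the power
// formula yields 2*exp(i*pi/3) = 1 + sqrt(3)*i rather than -2.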

// (a+bi)^(c+di) =
//    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
//    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
    const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
    llvm::Value* d) {
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
  auto zero = llvm::ConstantFP::get(a->getType(), 0);
  auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
  auto one = llvm::ConstantFP::get(a->getType(), 1);
  auto half_c = FMul(one_half, c);

  TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                      EmitPow(component_type, aa_p_bb, half_c, ""));

  auto neg_d = FNeg(d);
  TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
  auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
  TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                      EmitExp(component_type, neg_d_arg_lhs, ""));
  auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
  TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
  auto half_d = FMul(one_half, d);
  auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
  TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
  TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
  // z^c is 0 if z = a+bi is 0, d is 0, and c > 0; 0^0 is defined to be 1.0.
  // See "Branch Cuts for Complex Elementary Functions, or Much Ado About
  // Nothing's Sign Bit", W. Kahan, Section 10.
  return Select(
      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
}
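
// Worked example of the formula above: for i^2 we have a=0, b=1, c=2, d=0,
// so aa_p_bb = 1, coeff = 1^1 * exp(0) = 1, and q = 2*atan2(1, 0) = pi,
// giving cos(pi) + i*sin(pi) = -1 + 0i, as expected.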

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComplexAdd(op, lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return EmitComplexSubtract(op, lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return EmitComplexMultiply(op, lhs_value, rhs_value);
    case HloOpcode::kDivide: {
      return EmitComplexDivide(op, lhs_value, rhs_value);
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractReal(lhs_value),
                                             EmitExtractReal(rhs_value), b_),
                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractImag(lhs_value),
                                             EmitExtractImag(rhs_value), b_));
        case ComparisonDirection::kNe:
          return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value), b_),
                    llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value), b_));
        default:
          return Unimplemented(
              "complex comparison '%s'",
              ComparisonDirectionToString(op->comparison_direction()));
      }
    }
    case HloOpcode::kPower: {
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      return EmitComplexPower(op, a, b, c, d);
    }
    case HloOpcode::kAtan2: {
      // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
      auto y = lhs_value;
      auto x = rhs_value;
      TF_ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x));
      TF_ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y));
      TF_ASSIGN_OR_RETURN(auto x_squared_plus_y_squared,
                          EmitComplexAdd(op, x_squared, y_squared));
      auto component_type =
          primitive_util::ComplexComponentType(op->shape().element_type());
      TF_ASSIGN_OR_RETURN(
          auto sqrt_x_squared_plus_y_squared,
          EmitComplexSqrt(op, component_type, x_squared_plus_y_squared));
      auto type = llvm_ir::PrimitiveTypeToIrType(component_type, module_);
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto one = llvm::ConstantFP::get(type, 1.0);
      auto i = EmitComposeComplex(op, zero, one);
      TF_ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y));
      TF_ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y));
      TF_ASSIGN_OR_RETURN(
          auto div_result,
          EmitComplexDivide(op, x_plus_iy, sqrt_x_squared_plus_y_squared));
      TF_ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result));
      auto negative_one = llvm::ConstantFP::get(type, -1.0);
      auto negative_i = EmitComposeComplex(op, zero, negative_one);
      return EmitComplexMultiply(op, negative_i, log_result);
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
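
// Sanity check for the kAtan2 identity above: with real inputs y = 0 and
// x = 1, (x + i*y)/sqrt(x^2 + y^2) = 1, log(1) = 0, and -i * 0 = 0, which
// agrees with atan2(0, 1) = 0.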

llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}

llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto negative_half = llvm::ConstantFP::get(type, -0.5);
  // When x is large, the naive evaluation of ln(x + 1) is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // When x is small (|x| below the sqrt(2) - 1 threshold defined next), use a
  // rational approximation. The approximation below is based on one from the
  // Cephes Mathematical Library.
  //
  // sqrt(2) - 1.
  const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;

  static const std::array<double, 7> kDenominatorCoeffs{
      1.,
      1.5062909083469192043167E1,
      8.3047565967967209469434E1,
      2.2176239823732856465394E2,
      3.0909872225312059774938E2,
      2.1642788614495947685003E2,
      6.0118660497603843919306E1,
  };

  static const std::array<double, 7> kNumeratorCoeffs{
      4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
      6.5787325942061044846969E0,  2.9911919328553073277375E1,
      6.0949667980987787057556E1,  5.7112963590585538103336E1,
      2.0039553499201281259648E1,
  };

  auto x_squared = FMul(x, x);
  TF_ASSIGN_OR_RETURN(auto denominator,
                      EvaluatePolynomial(type, x, kDenominatorCoeffs));
  TF_ASSIGN_OR_RETURN(auto numerator,
                      EvaluatePolynomial(type, x, kNumeratorCoeffs));
  auto for_small_x = FDiv(numerator, denominator);
  for_small_x = FMul(FMul(x, x_squared), for_small_x);
  for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
  for_small_x = FAdd(x, for_small_x);

  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small = FCmpOLT(
      abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}
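
// Why the small-x path matters: for tiny x, x + 1 rounds away most of x
// before the log is taken. In f32, 1.0f + 1e-8f == 1.0f exactly, so the naive
// path would return log(1) = 0, while the rational approximation above
// returns approximately 1e-8 - 5e-17, correct to within rounding.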

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
                                                    llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
  return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value,
                                                   absl::string_view name) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                      {value->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto half = llvm::ConstantFP::get(type, 0.5);
  auto zero = llvm::ConstantFP::get(type, 0.0);

  // expm1(x) == tanh(x/2) * (exp(x) + 1).
  // x/2 can underflow; if it does, we approximate expm1 with x.
  auto x_over_two = FMul(x, half);
  auto x_over_two_is_zero = FCmpOEQ(x_over_two, zero);
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
  // Use the naive exp(x) - 1 calculation if |x| > 0.5.
  auto x_magnitude_is_large = FCmpOGT(abs_x, half);
  TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two));
  TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, ""));
  auto exp_of_x_plus_one = FAdd(exp_of_x, one);
  auto exp_of_x_minus_one = FSub(exp_of_x, one);
  auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one);
  expm1_of_x = Select(x_magnitude_is_large, exp_of_x_minus_one, expm1_of_x);
  expm1_of_x = Select(x_over_two_is_zero, x, expm1_of_x);
  return expm1_of_x;
}
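
// The identity used above follows from tanh(x/2) = (exp(x) - 1)/(exp(x) + 1):
// multiplying by (exp(x) + 1) recovers exp(x) - 1 while avoiding the
// catastrophic cancellation that computing exp(x) - 1 directly would incur
// for small |x|.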

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   absl::string_view name) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {lhs->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
  auto abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
                      EmitPow(prim_type, abs_value, third, ""));
  auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                 {abs_res, value}, {type}, b_);
  return signed_res;
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
    absl::string_view /*name*/) {
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  return Unimplemented("tanh");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) {
  return EmitReducePrecisionIR(
      /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
      /*dest_exponent_bits=*/hlo->exponent_bits(),
      /*dest_mantissa_bits=*/hlo->mantissa_bits(),
      /*quiet_nans=*/false, b_);
}

static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
                                             llvm::Value* lhs, llvm::Value* rhs,
                                             llvm::Value* shift_result,
                                             bool saturate_to_sign_bit) {
  llvm::IntegerType* integer_type =
      llvm::cast<llvm::IntegerType>(lhs->getType());
  unsigned integer_bitsize = integer_type->getBitWidth();
  llvm::ConstantInt* integer_bitsize_constant =
      llvm::ConstantInt::get(integer_type, integer_bitsize);
  llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
  llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
  llvm::Value* saturated_value;
  if (saturate_to_sign_bit) {
    saturated_value =
        b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
  } else {
    saturated_value = zero;
  }
  llvm::Value* shift_amt_in_range =
      b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
  return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
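
// For example, with i32 operands an HLO shift by 40 would produce a poison
// value in raw LLVM IR; the final select above instead yields 0, or -1 for an
// arithmetic right shift of a negative value, so the shift has a defined
// result.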

llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
}

llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
}

llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
                                                  integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(
      integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
  return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
}

llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
                                                          llvm::Value* rhs) {
  return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
             ICmpEQ(rhs, GetMinusOne(rhs->getType())));
}

llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   bool is_signed) {
  // Integer division overflow behavior:
  //
  //   X / 0 == -1
  //   INT_SMIN /s -1 == INT_SMIN

  if (!is_signed) {
    llvm::Value* udiv_is_unsafe = IsZero(rhs);
    llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
    llvm::Value* safe_div = UDiv(lhs, safe_rhs);
    return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
  }

  llvm::Value* has_zero_divisor = IsZero(rhs);
  llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
  llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
  llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
  llvm::Value* safe_div = SDiv(lhs, safe_rhs);

  return Select(
      has_zero_divisor, GetMinusOne(lhs->getType()),
      Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
}

llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      bool is_signed) {
  // Integer remainder overflow behavior:
  //
  //   X % 0 == X
  //   INT_SMIN %s -1 == 0

  if (!is_signed) {
    llvm::Value* urem_is_unsafe = IsZero(rhs);
    llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
    llvm::Value* safe_rem = URem(lhs, safe_rhs);
    return Select(urem_is_unsafe, lhs, safe_rem);
  }

  llvm::Value* has_zero_divisor = IsZero(rhs);
  llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
  llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
  llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
  llvm::Value* safe_rem = SRem(lhs, safe_rhs);

  return Select(
      has_zero_divisor, lhs,
      Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
}

llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
                                                llvm::Value* exponent,
                                                bool is_signed) {
  // Exponentiation by squaring:
  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
  int bits = 6;  // Everything else would overflow for any exponent > 1, as
                 // 2^64 is the largest possible exponent for a 64-bit
                 // integer, and that's 1 << 6.
  llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
  llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
  llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
  llvm::Value* original_base = base;
  llvm::Value* original_exponent = exponent;

  // Unroll the loop at compile time.
  for (int i = 0; i < bits; i++) {
    accumulator =
        b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
                         b_->CreateMul(accumulator, base), accumulator);
    base = b_->CreateMul(base, base);
    exponent = b_->CreateLShr(exponent, 1);
  }
  return b_->CreateSelect(
      b_->CreateICmpSGE(original_exponent, zero), accumulator,
      b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
}
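
// Worked example of the unrolled loop above: base = 3, exponent = 5 (binary
// 101). Iteration 0: low bit set, accumulator = 3, base = 9. Iteration 1: bit
// clear, base = 81. Iteration 2: bit set, accumulator = 3 * 81 = 243 = 3^5.
// The remaining iterations leave the accumulator unchanged. Negative
// exponents hit the final select: 1 if the base is 1, otherwise 0 (the
// integer truncation of 1/base^n).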

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPredBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Per the reference interpreter, pred arithmetic should behave like
  // `int8_t(x) OP int8_t(y) != 0`. For most permitted ops, we can just emit
  // the underlying i8 op to achieve this (e.g. kAnd, kOr, kXor, kMultiply). In
  // the case of kAdd, we would need to insert a comparison instruction after
  // the addition, but it's both easier and faster to emit a bitwise or
  // instruction instead.
  //
  // For several of these ops, a faster bitwise implementation is available,
  // but LLVM is unlikely to be able to see it, since it gets IR that e.g.
  // loads i8s from memory, multiplies them, and writes the result back,
  // without any indication that the inputs were assumed to be 0 or 1. So, just
  // in case, help it out by choosing the faster instruction to begin with.
  switch (op->opcode()) {
    case HloOpcode::kCompare:
    case HloOpcode::kXor:
      return EmitIntegerBinaryOp(op, lhs_value, rhs_value, false);

    // zext(i1 x) + zext(i1 y) != 0 === or(x, y)
    // max(zext(i1 x), zext(i1 y)) != 0 === or(x, y)
    case HloOpcode::kAdd:
    case HloOpcode::kMaximum:
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);

    // zext(i1 x) * zext(i1 y) != 0 === and(x, y)
    // min(zext(i1 x), zext(i1 y)) != 0 === and(x, y)
    case HloOpcode::kMultiply:
    case HloOpcode::kMinimum:
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);

    // These opcodes are rejected by shape-inference for PRED elements; calling
    // them out here serves more as documentation than a necessary check.
    case HloOpcode::kDivide:
    case HloOpcode::kRemainder:
    case HloOpcode::kPower:
    case HloOpcode::kSubtract:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      return InternalError("Invalid binary op '%s' for pred",
                           HloOpcodeString(op->opcode()));

    default:
      return Unimplemented("binary pred op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return Add(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return Sub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return Mul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
    case HloOpcode::kRemainder:
      return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
              lhs_value, rhs_value, b_);
      }
    }
    case HloOpcode::kMinimum:
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitIntegerPow(lhs_value, rhs_value, is_signed);
    case HloOpcode::kXor:
      return Xor(lhs_value, rhs_value);

    // Shifting out bits >= the number of bits in the type being shifted
    // produces a poison value in LLVM which is basically "deferred undefined
    // behavior" -- doing something observable with such a value precipitates
    // UB. We replace the poison value with a constant to avoid this deferred
    // UB.
    case HloOpcode::kShiftRightArithmetic:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      AShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/true);
    case HloOpcode::kShiftLeft:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      Shl(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    case HloOpcode::kShiftRightLogical:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      LShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                         : llvm::ICmpInst::ICMP_UGE,
                               lhs_value, rhs_value),
                lhs_value, rhs_value);
}

llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                         : llvm::ICmpInst::ICMP_ULE,
                               lhs_value, rhs_value),
                lhs_value, rhs_value);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
                      operand_to_generator.at(hlo->operand(1))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
                      operand_to_generator.at(hlo->operand(2))(index));
  return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
                on_false_value);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
                      operand_to_generator.at(hlo->operand(1))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                      operand_to_generator.at(hlo->operand(2))(index));
  PrimitiveType prim_type = hlo->shape().element_type();
  if (primitive_util::IsFloatingPointType(prim_type)) {
    return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
  } else if (primitive_util::IsIntegralType(prim_type)) {
    bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
    return EmitIntegralMin(
        max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
  } else {
    return Unimplemented("Clamp unimplemented for %s",
                         PrimitiveType_Name(prim_type));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& source_index) {
  const int64_t concat_dim = hlo->dimensions(0);
  llvm::BasicBlock* init_block = b_->GetInsertBlock();

  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() != init_block->end()) {
    // Inserting into the middle.
    CHECK(init_block->getTerminator());
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    init_block->getTerminator()->eraseFromParent();
  } else {
    // Inserting at the end.
    CHECK(!init_block->getTerminator());
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  }

  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  llvm::PHINode* output = b_->CreatePHI(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      hlo->operands().size());
  auto prior_insert_point = b_->GetInsertPoint();

  b_->SetInsertPoint(init_block);

  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64_t> to_unique_operand_id;
  std::vector<int64_t> operand_usage_count;
  for (const HloInstruction* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64_t unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // To avoid emitting the same operand more than once, we create one basic
  // block for each *different* operand, with a PHI node for the different
  // source index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const HloInstruction* operand : hlo->operands()) {
    int64_t operand_id = to_unique_operand_id[operand];
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }

    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    source_index_phis[operand_id] =
        b_->CreatePHI(source_index.GetType(), operand_usage_count[operand_id]);
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    operand_multi_index[concat_dim] = b_->CreateNSWSub(
        operand_multi_index[concat_dim], source_index_phis[operand_id]);

    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());

    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }

  // We use bisection to select the input operand.
  int64_t current_offset = 0;

  // Offset for every operand.
  std::vector<std::pair<int64_t, const HloInstruction*>> cases;

  cases.reserve(hlo->operand_count());
  for (const HloInstruction* operand : hlo->operands()) {
    cases.emplace_back(current_offset, operand);
    current_offset += operand->shape().dimensions(concat_dim);
  }
  CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim));

  std::function<llvm::BasicBlock*(
      absl::Span<const std::pair<int64_t, const HloInstruction*>> operands)>
      emit_tree =
          [&](absl::Span<const std::pair<int64_t, const HloInstruction*>>
                  operands) {
            llvm::IRBuilder<>::InsertPointGuard guard(*b_);
            size_t mid = operands.size() / 2;
            const std::pair<int64_t, const HloInstruction*>& pivot =
                operands[mid];
            llvm::BasicBlock* block = llvm_ir::CreateBasicBlock(
                exit_block,
                absl::StrCat("concatenate.pivot.", pivot.first, "."), b_);
            b_->SetInsertPoint(block);

            // If there's only one element, we're done. The range is
            // contiguous, so we can just jump to the block for it.
            if (operands.size() == 1) {
              const std::pair<int64_t, const HloInstruction*>& operand =
                  operands.back();
              int64_t operand_id = to_unique_operand_id[operand.second];

              source_index_phis[operand_id]->addIncoming(
                  source_index.GetConstantWithIndexType(operand.first),
                  b_->GetInsertBlock());
              b_->CreateBr(emit_operand_blocks[operand_id]);
              return block;
            }

            // Take the middle element and recurse.
            llvm::Constant* pivot_const = llvm::ConstantInt::get(
                source_index[concat_dim]->getType(), pivot.first);
            llvm::Value* comp =
                b_->CreateICmpULT(source_index[concat_dim], pivot_const);

            llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid));
            llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid));

            b_->CreateCondBr(comp, left_block, right_block);
            return block;
          };

  Br(emit_tree(cases));

  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}
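
// Example of the bisection above: concatenating operands of sizes 2, 3, and 4
// along concat_dim gives cases with offsets {0, 2, 5}. For an output index of
// 6, the pivot checks 6 < 2 and 6 < 5 both fail, so control reaches the
// single-element range for the third operand; its PHI receives the offset 5,
// and the operand is read at index 6 - 5 = 1.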

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  // Emit IR to read dynamic start indices from hlo->operand(1).
  const HloInstruction* input_hlo = hlo->operand(0);
  const int64_t rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  for (int64_t i = 0; i < rank; ++i) {
    auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(1 + i))(zero_index));

    // Clamp the start index so that the sliced portion fits in the operand:
    //   start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    int64_t largest_valid_start_index =
        input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
  }

  std::vector<llvm::Value*> input_multi_index(rank);
  for (int64_t i = 0; i < rank; ++i) {
    // Emit IR which computes:
    //   input_index = start_index + offset_index
    input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
  }
  llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
                                      index_type);
  return operand_to_generator.at(input_hlo)(input_index);
}
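
// Clamping example: slicing 4 elements from a dimension of size 10 with a
// requested start index of 9 would read out of bounds, so the code above
// clamps the start index to largest_valid_start_index = 10 - 4 = 6, matching
// the HLO semantics of DynamicSlice.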

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape& indices_shape = hlo->operand(1)->shape();
  const Shape& output_shape = hlo->shape();

  const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();

  const llvm_ir::ElementGenerator& operand_generator =
      operand_to_generator.at(hlo->operand(0));
  const llvm_ir::ElementGenerator& indices_generator =
      operand_to_generator.at(hlo->operand(1));

  llvm::Type* index_type = index.GetType();
  // This is the index into `operand` that holds the element we want to
  // generate.
  std::vector<llvm::Value*> operand_multi_index;

  // First copy in the window indices to operand_multi_index. Also collect a
  // mapping from operand dimension to output window dimension. Elided window
  // dimensions map to -1.
  std::vector<int64_t> operand_to_output_dim(operand_shape.dimensions_size(),
                                             -1);
  for (int64_t i = 0, e = operand_shape.dimensions_size(),
               operand_index_dim = 0;
       i < e; i++) {
    if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      operand_multi_index.push_back(index.GetConstantWithIndexType(0));
    } else {
      int64_t output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
      operand_to_output_dim[i] = output_window_dim;
      operand_multi_index.push_back(index[output_window_dim]);
    }
  }

  // This is the index of the index vector in the start_indices tensor.
  std::vector<llvm::Value*> gather_index_index_components;
  {
    for (int64_t i = 0, e = output_shape.dimensions_size(); i < e; i++) {
      if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        gather_index_index_components.push_back(index[i]);
      }
    }

    if (gather_index_index_components.size() !=
        indices_shape.dimensions_size()) {
      gather_index_index_components.insert(
          gather_index_index_components.begin() +
              dim_numbers.index_vector_dim(),
          nullptr);
    }
  }

  auto add_to_operand_index = [&](llvm::Value* index_component, int64_t dim) {
    auto index_component_type = index_component->getType();
    auto extended_type = index_component_type->getScalarSizeInBits() >=
                                 index_type->getScalarSizeInBits()
                             ? index_component_type
                             : index_type;
    // Possibly extend the value at the beginning to ensure the clamping logic
    // stays in bounds.
    auto maybe_extended_index =
        index_component_type != extended_type
            ? b_->CreateSExt(index_component, extended_type)
            : index_component;
    int64_t operand_dim = dim_numbers.start_index_map(dim);
    int64_t output_dim = operand_to_output_dim[operand_dim];
    // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
    // This means we set the iteration index to 0, so for the purpose of the
    // following calculations we can consider the output dimension size to be 1.
    int64_t output_dim_size =
        output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
    int64_t largest_valid_start_index =
        operand_shape.dimensions(operand_dim) - output_dim_size;
    CHECK_GE(largest_valid_start_index, 0);

    // Clamp the gather index so that the gather region fits in the operand.
    //   clamped_index =
    //       clamp(gather_dim_component_extended, 0, largest_valid_start_index);
    bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
    auto clamped_index = EmitIntegralMin(
        llvm::ConstantInt::get(extended_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
                        maybe_extended_index, is_signed),
        is_signed);
    // Truncate at the end to the optimized index size.
    auto maybe_truncated_clamped_index = extended_type != index_type
                                             ? Trunc(clamped_index, index_type)
                                             : clamped_index;

    operand_multi_index[operand_dim] =
        Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
  };

  if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
    IrArray::Index gather_index_index(gather_index_index_components,
                                      indices_shape, index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                        indices_generator(gather_index_index));
    add_to_operand_index(gather_dim_component, 0);
  } else {
    int64_t index_vector_size =
        indices_shape.dimensions(dim_numbers.index_vector_dim());
    for (int64_t i = 0; i < index_vector_size; i++) {
      gather_index_index_components[dim_numbers.index_vector_dim()] =
          index.GetConstantWithIndexType(i);
      IrArray::Index gather_index_index(gather_index_index_components,
                                        indices_shape, index_type);
      TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                          indices_generator(gather_index_index));
      add_to_operand_index(gather_dim_component, i);
    }
  }
  IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
  return operand_generator(operand_index);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* input_hlo = hlo->operand(0);
  const HloInstruction* update_hlo = hlo->operand(1);
  const HloInstruction* start_hlo = hlo->operand(2);
  // Calculate slice start/end indices.
  const int64_t rank = input_hlo->shape().rank();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  std::vector<llvm::Value*> slice_limit_multi_index(rank);
  // 'slice_intersection' accumulates (ANDs) the per-dimension conditions
  // under which 'input' is overwritten with 'update'.
  llvm::Value* slice_intersection = b_->getTrue();

  for (int64_t i = 0; i < rank; ++i) {
    llvm::Type* index_type = index[0]->getType();
    auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };

    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(2 + i))(zero_index));

    // Clamp the start index so that the update region fits in the operand.
    //   start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    llvm::Value* update_dim_size =
        index_typed_const(update_hlo->shape().dimensions(i));
    int64_t largest_valid_start_index =
        input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
    slice_limit_multi_index[i] =
        Add(slice_start_multi_index[i], update_dim_size);

    slice_intersection =
        And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
            "slice_intersection");
    slice_intersection =
        And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
            "slice_intersection");
  }

  // Emit:
  //   if (slice_intersection) -> return data from 'update'.
  //   else                    -> return data from 'input'.
  llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "ret_value_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);

  // Handle true BB (return data from 'update').
  SetToFirstInsertPoint(if_data.true_block, b_);
  // Compute the update index for the intersection case.
  std::vector<llvm::Value*> update_multi_index(rank);
  for (int64_t i = 0; i < rank; ++i) {
    update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
  }
  llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
                                       index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                      operand_to_generator.at(update_hlo)(update_index));
  Store(true_value, ret_value_addr);

  // Handle false BB (return data from 'input').
  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
                      operand_to_generator.at(input_hlo)(index));
  Store(false_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  return Load(ret_value_addr->getAllocatedType(), ret_value_addr);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& padded_index) {
  std::vector<llvm::Value*> multi_index = padded_index.multidim();
  llvm::Value* in_bounds = b_->getTrue();
  for (size_t i = 0; i < multi_index.size(); ++i) {
    auto index_typed_const = [=](int64_t n) {
      return padded_index.GetConstantWithIndexType(n);
    };
    const auto& pad_dim = hlo->padding_config().dimensions(i);
    multi_index[i] =
        Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
    in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
                    "in_bounds");
    in_bounds =
        And(in_bounds,
            ICmpEQ(index_typed_const(0),
                   URem(multi_index[i],
                        index_typed_const(pad_dim.interior_padding() + 1))),
            "in_bounds");
    multi_index[i] =
        SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
    in_bounds =
        And(in_bounds,
            ICmpSLT(multi_index[i],
                    index_typed_const(hlo->operand(0)->shape().dimensions(i))),
            "in_bounds");
  }

  // if (in_bounds) {
  //   ret_value = operand0[index];  // source
  // } else {
  //   ret_value = *operand1;  // padding
  // }
  llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "pad_result_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);
  llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
                                padded_index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  Store(operand_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                      operand_to_generator.at(hlo->operand(1))(
                          IrArray::Index(index.GetType())));
  Store(padding_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  // Don't create phi(operand_value, padding_value) here, because invoking
  // operand_to_generator may create new basic blocks, making the parent
  // of operand_value or padding_value no longer a predecessor of
  // if_data.after_block.
  return Load(ret_value_addr->getAllocatedType(), ret_value_addr);
}
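
// Index-mapping example: with edge_padding_low = 1 and interior_padding = 1
// on a dimension of size 3 (and no high padding), the output has indices
// 0..5. Output index 3 maps to 3 - 1 = 2, which is divisible by
// interior_padding + 1 = 2, so it reads operand index 2 / 2 = 1; output
// indices 0, 2, and 4 fail the divisibility or bounds checks and take the
// padding value.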

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& dot_result_index) {
  auto lhs_generator = operand_to_generator.at(hlo->operand(0));
  auto rhs_generator = operand_to_generator.at(hlo->operand(1));

  const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
  int64_t lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
  int64_t rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);

  int64_t contracted_dim_size =
      hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
  int64_t lhs_dims = hlo->operand(0)->shape().dimensions_size();
  int64_t rhs_dims = hlo->operand(1)->shape().dimensions_size();

  llvm::Type* index_type = dot_result_index.GetType();
  auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, c);
  };

  std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
      IrName(hlo, "inner"), index_typed_const(0),
      index_typed_const(contracted_dim_size), index_typed_const(1), b_);

  SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
  PrimitiveType primitive_type = hlo->shape().element_type();
  llvm::Type* primitive_type_llvm =
      llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
  llvm::AllocaInst* accumulator_alloca =
      llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
  Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);

  // This is the inner reduction loop for a dot operation that produces
  // one element in the output. If the operands to the dot operation have
  // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
  // Given an output index [a,b,c,d,e] in the result, we compute:
  //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))

  std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
  for (int64_t i = 0; i < lhs_dims - 1; i++) {
    lhs_multi_index.push_back(dot_result_index[i]);
  }
  lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
                           index_type);

  int64_t num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
  for (int64_t i = 0; i < num_batch_dims; i++) {
    rhs_multi_index.push_back(
        dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
  }
  for (int64_t i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
    rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
  }
  rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
                           index_type);

  llvm::Value* current_accumulator =
      Load(accumulator_alloca->getAllocatedType(), accumulator_alloca);
  TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
  TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
  llvm::Value* next_accumulator =
      EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
  Store(next_accumulator, accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
  return Load(accumulator_alloca->getAllocatedType(), accumulator_alloca);
}
2437
MakeElementGenerator(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator)2438 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2439 const HloInstruction* hlo,
2440 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2441 switch (hlo->opcode()) {
2442 case HloOpcode::kAbs:
2443 case HloOpcode::kRoundNearestAfz:
2444 case HloOpcode::kRoundNearestEven:
2445 case HloOpcode::kCeil:
2446 case HloOpcode::kClz:
2447 case HloOpcode::kConvert:
2448 case HloOpcode::kBitcastConvert:
2449 case HloOpcode::kCos:
2450 case HloOpcode::kExp:
2451 case HloOpcode::kExpm1:
2452 case HloOpcode::kFloor:
2453 case HloOpcode::kImag:
2454 case HloOpcode::kIsFinite:
2455 case HloOpcode::kLog:
2456 case HloOpcode::kLog1p:
2457 case HloOpcode::kNegate:
2458 case HloOpcode::kNot:
2459 case HloOpcode::kPopulationCount:
2460 case HloOpcode::kReal:
2461 case HloOpcode::kRsqrt:
2462 case HloOpcode::kSign:
2463 case HloOpcode::kSin:
2464 case HloOpcode::kSqrt:
2465 case HloOpcode::kCbrt:
2466 case HloOpcode::kTanh:
2467 return [this, hlo, &operand_to_generator](
2468 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2469 TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2470 operand_to_generator.at(hlo->operand(0))(index));
2471 return EmitUnaryOp(hlo, operand_value);
2472 };
2473 case HloOpcode::kAdd:
2474 case HloOpcode::kAnd:
2475 case HloOpcode::kAtan2:
2476 case HloOpcode::kCompare:
2477 case HloOpcode::kComplex:
2478 case HloOpcode::kDivide:
2479 case HloOpcode::kMaximum:
2480 case HloOpcode::kMinimum:
2481 case HloOpcode::kMultiply:
2482 case HloOpcode::kOr:
2483 case HloOpcode::kXor:
2484 case HloOpcode::kPower:
2485 case HloOpcode::kRemainder:
2486 case HloOpcode::kShiftLeft:
2487 case HloOpcode::kShiftRightArithmetic:
2488 case HloOpcode::kShiftRightLogical:
2489 case HloOpcode::kSubtract:
2490 return [this, hlo, &operand_to_generator](
2491 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2492 const HloInstruction* lhs = hlo->operand(0);
2493 const HloInstruction* rhs = hlo->operand(1);
2494 TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2495 operand_to_generator.at(lhs)(index));
2496 TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2497 operand_to_generator.at(rhs)(index));
2498 return EmitBinaryOp(hlo, lhs_value, rhs_value);
2499 };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalSelect(hlo, operand_to_generator, index);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalClamp(hlo, operand_to_generator, index);
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
        return EmitElementalConcatenate(hlo, operand_to_generator,
                                        target_index);
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        std::vector<llvm::Value*> source_multi_index = target_index.multidim();
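        // Mirror the coordinate along each reversed dimension:
        //   source[dim] = (dimension size - 1) - target[dim].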
        for (int64_t dim : hlo->dimensions()) {
          source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
                                            hlo->shape().dimensions(dim) - 1),
                                        target_index[dim]);
        }
        llvm_ir::IrArray::Index source_index(
            source_multi_index, operand->shape(), target_index.GetType());
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
        return operand_to_generator.at(operand)(
            target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
                                                hlo->dimensions(), b_));
      };
    case HloOpcode::kIota:
      return [this, hlo](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        auto* iota = Cast<HloIotaInstruction>(hlo);
        PrimitiveType element_type = iota->shape().element_type();
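        // A rank > 1 iota is equivalent to a rank-1 iota along iota_dimension
        // broadcast to the full shape, so reuse the broadcast index mapping
        // to recover the rank-1 element index.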
        IrArray::Index elem_index =
            iota->shape().rank() > 1
                ? target_index.SourceIndexOfBroadcast(
                      iota->shape(),
                      ShapeUtil::MakeShapeWithDescendingLayout(
                          element_type,
                          {iota->shape().dimensions(iota->iota_dimension())}),
                      {iota->iota_dimension()}, b_)
                : target_index;
        llvm::Value* elem_index_linear = elem_index.linear();
        if (elem_index_linear == nullptr) {
          std::vector<int64_t> iota_bound = {
              iota->shape().dimensions(iota->iota_dimension())};
          elem_index_linear = elem_index.Linearize(iota_bound, b_);
        }
        Shape component_shape =
            ShapeUtil::ElementIsComplex(iota->shape())
                ? ShapeUtil::ComplexComponentShape(iota->shape())
                : iota->shape();
        PrimitiveType component_element_type = component_shape.element_type();
        llvm::Value* iota_result;
        if (primitive_util::IsIntegralType(component_element_type)) {
          iota_result = b_->CreateIntCast(
              elem_index_linear,
              llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
              /*isSigned=*/false);
        } else {
          TF_RET_CHECK(
              primitive_util::IsFloatingPointType(component_element_type))
              << component_element_type;
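          // BF16 results are materialized by converting the linear index to
          // F32 and then rounding down to BF16 via EmitF32ToBF16.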
          llvm::Type* float_ir_type;
          if (component_element_type == BF16) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }
          llvm::Value* float_val =
              b_->CreateUIToFP(elem_index_linear, float_ir_type);
          if (component_element_type == BF16) {
            TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
          } else {
            iota_result = float_val;
          }
        }
        if (ShapeUtil::ElementIsComplex(iota->shape())) {
          return EmitComposeComplex(iota, iota_result, nullptr);
        } else {
          return iota_result;
        }
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*operand_shape=*/hlo->operand(0)->shape(),
            /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/b_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
      };

    case HloOpcode::kGather:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalGather(hlo, operand_to_generator, index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
                                               index);
      };
    case HloOpcode::kBitcast:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            GetSourceIndexOfBitcast(index, hlo));
      };
    case HloOpcode::kReshape:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kCopy:
      return [hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        IrArray::Index source_index(target_index.multidim(),
                                    hlo->operand(0)->shape(),
                                    target_index.GetType());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * operand_value,
            operand_to_generator.at(hlo->operand(0))(source_index));
        return operand_value;
      };
    case HloOpcode::kTranspose:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
      };
    case HloOpcode::kPad:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        return EmitElementalPad(hlo, operand_to_generator, padded_index);
      };

    case HloOpcode::kDot:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
      };
    case HloOpcode::kMap:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        std::vector<llvm::Value*> operands;
        for (int i = 0; i < hlo->operand_count(); i++) {
          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                              operand_to_generator.at(hlo->operand(i))(index));
          operands.push_back(operand_value);
        }
        return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
      };
    case HloOpcode::kReduceWindow:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_window_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_window_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduceWindow(
            Cast<HloReduceWindowInstruction>(hlo), std::move(input_generators),
            std::move(initial_value_generators), index);
      };
    case HloOpcode::kReduce:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduce(reduce_instr, std::move(input_generators),
                                   std::move(initial_value_generators), index);
      };
    case HloOpcode::kConvolution:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return EmitConvolution(hlo, operand_to_generator, index);
      };
    default:
      return [hlo](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()));
      };
  }
}

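// Complex values are represented as a two-element {real, imag} struct in the
// emitted IR; the helpers below extract and assemble those components.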
llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
  return ExtractValue(value, {0});
}

llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
  return ExtractValue(value, {1});
}

llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) {
  auto cplx_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
  auto complex =
      InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
  if (imag != nullptr) {
    complex = InsertValue(complex, imag, {1});
  }
  return complex;
}

llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
                                            llvm::Value* accumulator,
                                            xla::PrimitiveType primitive_type) {
  if (primitive_util::IsComplexType(primitive_type)) {
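    // Complex multiply-add expands as
    //   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i,
    // accumulated componentwise into the {real, imag} struct.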
    llvm::Value* product_real =
        FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
             FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
    llvm::Value* product_imag =
        FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
             FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
    llvm::Value* next_accumulator = InsertValue(
        accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
    return InsertValue(next_accumulator,
                       FAdd(EmitExtractImag(accumulator), product_imag), {1});
  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
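    // FPCast reconciles the product type with the accumulator type, e.g. an
    // F32 accumulator over F16 products.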
    return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
  } else if (primitive_type == PRED) {
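    // For PRED, multiply-add degenerates to boolean logic: AND is the
    // product and OR is the accumulation.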
    return Or(accumulator, And(lhs, rhs));
  }
  return Add(accumulator, Mul(lhs, rhs));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
    const HloMapInstruction* map_instr,
    absl::Span<llvm::Value* const> elemental_operands) {
  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> values,
      EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
                          llvm_ir::IrName(map_instr), /*is_reducer=*/false));
  CHECK_EQ(values.size(), 1);
  return values[0];
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  // Pseudocode:
  // for each output index O
  //   value = init_value
  //   for each window index W
  //     for each dimension i from 0 to rank - 1
  //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
  //     if I is in bounds of the input
  //       value = function(value, input[I])
  //   output[O] = value
  int64_t input_count = reduce_window->input_count();
  std::vector<PrimitiveType> operand_element_types;
  std::vector<llvm::Type*> accum_types;
  std::vector<llvm::Value*> accum_ptrs;
  for (int64_t operand_index = 0; operand_index < input_count;
       ++operand_index) {
    auto operand = reduce_window->inputs()[operand_index];
    PrimitiveType operand_element_type = operand->shape().element_type();
    operand_element_types.push_back(operand_element_type);
    llvm::Type* llvm_type =
        llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
    accum_types.push_back(llvm_type);
    llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
        "reduce_window_accum_ptr", b_);
    accum_ptrs.push_back(accum_ptr);
    {
      auto initial_value_generator = initial_value_generators[operand_index];
      TF_ASSIGN_OR_RETURN(
          llvm::Value* const init_value,
          initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
      Store(init_value, accum_ptr);
    }
  }

  llvm::Type* index_type = index.GetType();
  auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
    return index.GetConstantWithIndexType(c);
  };

  const Window& window = reduce_window->window();
  llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
  std::vector<int64_t> window_size;
  const auto& dimensions = window.dimensions();
  window_size.reserve(dimensions.size());
  for (const auto& dim : dimensions) {
    window_size.push_back(dim.size());
  }
  const IrArray::Index window_index = loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
  CHECK_EQ(window_index.size(), index.size());

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

  std::vector<llvm::Value*> input_multi_index(index.size());
  llvm::Value* in_bounds = b_->getInt1(true);
  for (size_t i = 0; i < index.size(); ++i) {
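    // The (dilated) input coordinate for output index O and window index W is
    //   O[i] * stride[i] + W[i] * window_dilation[i] - pad_low[i].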
    llvm::Value* strided_index =
        NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
    input_multi_index[i] = NSWSub(
        NSWAdd(
            strided_index,
            NSWMul(window_index[i],
                   index_typed_const(window.dimensions(i).window_dilation()))),
        index_typed_const(window.dimensions(i).padding_low()));

    // We need to verify that we are not in the dilated base area.
    llvm::Value* dilation_condition =
        ICmpEQ(SRem(input_multi_index[i],
                    index_typed_const(window.dimensions(i).base_dilation())),
               index_typed_const(0));
    in_bounds = And(in_bounds, dilation_condition);

    // Apply base dilation to the index.
    input_multi_index[i] =
        SDiv(input_multi_index[i],
             index_typed_const(window.dimensions(i).base_dilation()));

    // We must check whether 0 <= input_multi_index[i] < bound, as otherwise
    // we are in the padding and can skip the computation. This comparison is
    // equivalent to the unsigned comparison input_multi_index[i] < bound, as
    // a negative value wraps to a large positive value.
    in_bounds =
        And(in_bounds,
            ICmpULT(input_multi_index[i],
                    index_typed_const(
                        reduce_window->inputs()[0]->shape().dimensions(i))));
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the padding, so carry out the computation.
  std::vector<llvm::Value*> input_values(reduce_window->operand_count());
  IrArray::Index input_index(input_multi_index,
                             reduce_window->inputs()[0]->shape(), index_type);
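  // The reducer is called with the current accumulator values in the first
  // input_count slots, followed by the freshly generated input elements.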
  for (int64_t operand_idx = 0; operand_idx < input_count; ++operand_idx) {
    TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
                        input_generators[operand_idx](input_index));
    input_values[input_count + operand_idx] = input_value;
    input_values[operand_idx] =
        Load(llvm::cast<llvm::AllocaInst>(accum_ptrs[operand_idx])
                 ->getAllocatedType(),
             accum_ptrs[operand_idx]);
  }
  TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
                      EmitThreadLocalCall(*reduce_window->to_apply(),
                                          input_values, "reducer_function",
                                          /*is_reducer=*/true));

  for (int64_t operand_idx = 0; operand_idx < accum_values.size();
       ++operand_idx) {
    Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
  }

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return EmitAccumResult(accum_ptrs, accum_types,
                         reduce_window->shape().IsTuple());
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
    const HloReduceInstruction* reduce,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  const Shape& out_shape = reduce->shape();
  bool is_variadic = !out_shape.IsArray();
  int accumulators_count = 1;
  if (is_variadic) {
    CHECK(out_shape.IsTuple());
    accumulators_count = out_shape.tuple_shapes_size();
  }

  absl::Span<const int64_t> reduced_dimensions(reduce->dimensions());

  std::vector<llvm::Value*> accumulator_addrs;
  std::vector<llvm::Type*> accumulator_types;
  llvm::Type* index_type = index.GetType();
  for (int i = 0; i < accumulators_count; i++) {
    const Shape& element_shape =
        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
    PrimitiveType accumulator_type = element_shape.element_type();
    llvm::Type* accumulator_llvm_type =
        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
    accumulator_types.push_back(accumulator_llvm_type);

    // Initialize an accumulator with init_value.
    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
    Store(init_value, accumulator_addr);
    accumulator_addrs.push_back(accumulator_addr);
  }

  // The enclosing loops go over all the target elements. Now we have to
  // compute the actual target element. For this, we build a new loop nest to
  // iterate over all the reduction dimensions in the argument.
  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
  // are placed for each dimension in dimensions, and all the rest are
  // nullptrs.
  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
  const HloInstruction* arg = reduce->operand(0);
  std::vector<llvm::Value*> input_multi_index =
      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                         "reduction_dim");

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());

  // Build a full index for the input argument, using input_multi_index as the
  // base. In input_multi_index only the reduction dimensions are filled in.
  // We fill in the rest of the dimensions with induction Value*s taken from
  // 'index' which iterates over the target array. See the high-level
  // description in the XLA documentation for details.
  auto it = index.begin();

  for (auto& i : input_multi_index) {
    if (i == nullptr) {
      i = *it++;
    }
  }
  CHECK(index.end() == it);
  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                      index_type);

  std::vector<llvm::Value*> reduction_operands;
  for (llvm::Value* accum : accumulator_addrs) {
    llvm::Value* accum_value =
        Load(llvm::cast<llvm::AllocaInst>(accum)->getAllocatedType(), accum);
    reduction_operands.push_back(accum_value);
  }

  for (int i = 0; i < accumulators_count; i++) {
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
                        input_generators[i](input_index));
    reduction_operands.push_back(input_element);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
                          "reduce_function", /*is_reducer=*/true));

  CHECK(results.size() == accumulators_count);
  for (int i = 0; i < accumulators_count; i++) {
    Store(results[i], accumulator_addrs[i]);
  }
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
  return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
    absl::Span<llvm::Value* const> accumulator_addrs,
    llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
  TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
  if (is_variadic) {
    // Emit a structure, as that is what the LoopEmitter expects.
    llvm::Value* returned_structure = llvm::UndefValue::get(
        llvm::StructType::get(b()->getContext(), accumulator_types));
    for (int64_t i = 0; i < accumulator_addrs.size(); i++) {
      llvm::Value* accumulator_value =
          Load(accumulator_types[i], accumulator_addrs[i]);
      returned_structure =
          b()->CreateInsertValue(returned_structure, accumulator_value, i);
    }
    return returned_structure;
  } else {
    CHECK_EQ(accumulator_addrs.size(), 1);
    return Load(accumulator_types[0], accumulator_addrs[0]);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
    const HloInstruction* convolution,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  TF_RET_CHECK(convolution->batch_group_count() == 1);
  const HloInstruction* lhs = convolution->operand(0);
  const auto& input_generator = operand_to_generator.at(lhs);
  const HloInstruction* rhs = convolution->operand(1);
  const auto& kernel_generator = operand_to_generator.at(rhs);
  const Window& window = convolution->window();

  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();
  int num_spatial_dims = dnums.output_spatial_dimensions_size();
  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
  }
  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
  llvm::Value* batch = index[dnums.output_batch_dimension()];

  // We will accumulate the products into this sum to calculate the output
  // entry at the given index.
  PrimitiveType lhs_element_type = lhs->shape().element_type();
  llvm::Type* lhs_llvm_type =
      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
  // Upcast the accumulator to F32 from F16 for increased precision.
  llvm::Type* accumulator_type =
      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
  llvm::AllocaInst* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accumulator_type, "convolution_sum_address", b_);
  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
  Store(constant_zero, sum_address);

  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial[i] =
        loops
            .AddLoop(
                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
                absl::StrCat("k", i))
            ->GetIndVarValue();
  }
  const int64_t input_group_size =
      rhs->shape().dimensions(dnums.kernel_input_feature_dimension());
  const int64_t feature_group_count = convolution->feature_group_count();
  const int64_t output_group_size =
      rhs->shape().dimensions(dnums.kernel_output_feature_dimension()) /
      feature_group_count;
  llvm::Value* input_feature =
      loops.AddLoop(0, input_group_size, "iz")->GetIndVarValue();

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

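  // For grouped convolution, each group of output features convolves against
  // its own slice of the input features: group_id picks the group for this
  // output feature, and lhs_input_feature offsets the kernel's input-feature
  // coordinate into that group's slice of the lhs.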
  llvm::Value* group_id =
      SDiv(output_feature, b_->getInt64(output_group_size));
  llvm::Value* lhs_input_feature =
      NSWAdd(input_feature, NSWMul(group_id, b_->getInt64(input_group_size)));

  // Calculate the spatial index in the input array, taking striding, dilation
  // and padding into account. An index in the padding will be out of the
  // bounds of the array.
  const auto calculate_input_index = [this](llvm::Value* output_index,
                                            llvm::Value* kernel_index,
                                            const WindowDimension& window_dim) {
    llvm::Value* strided_index =
        NSWMul(output_index, b_->getInt64(window_dim.stride()));
    llvm::Value* dilated_kernel_index =
        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
                  b_->getInt64(window_dim.padding_low()));
  };
  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] = calculate_input_index(
        output_spatial[i], kernel_spatial[i], window.dimensions(i));
  }

  // We need to check if 0 <= input dim < bound, as otherwise we are in the
  // padding so that we can skip the computation. That is equivalent to input
  // dim < bound as an *unsigned* comparison, since a negative value will wrap
  // to a large positive value. The input dim is dilated, so we need to dilate
  // the bound as well to match.

  // We also need to check that the input coordinates are not in one of the
  // holes created by base dilation.
  const auto not_in_hole = [&](llvm::Value* input_index,
                               int64_t base_dilation) {
    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
    return ICmpEQ(remainder, b_->getInt64(0));
  };

  llvm::Value* in_bounds_condition = b_->getInt1(true);
  for (int i = 0; i < num_spatial_dims; ++i) {
    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
        window.dimensions(i).base_dilation()));
    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
    llvm::Value* dim_not_in_hole =
        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
    in_bounds_condition = And(in_bounds_condition, dim_ok);
  }

  // Now we need to map the dilated base coordinates back to the actual
  // data indices on the lhs.
  const auto undilate = [&](llvm::Value* input_index, int64_t base_dilation) {
    return SDiv(input_index, b_->getInt64(base_dilation));
  };
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] =
        undilate(input_spatial[i], window.dimensions(i).base_dilation());
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the padding, so carry out the computation.
  int num_dims = num_spatial_dims + 2;
  std::vector<llvm::Value*> input_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
  }
  input_multi_index[dnums.input_feature_dimension()] = lhs_input_feature;
  input_multi_index[dnums.input_batch_dimension()] = batch;

  std::vector<llvm::Value*> kernel_multi_index(num_dims);
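  // If the window is reversed along a dimension, read the kernel back to
  // front in that dimension.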
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
        window.dimensions(i).window_reversal()
            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
                     kernel_spatial[i])
            : kernel_spatial[i];
  }

  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;

  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
                                      b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                      input_generator(input_index));
  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
                                       b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
                      kernel_generator(kernel_index));
  llvm::Value* sum =
      EmitMulAdd(input_value, kernel_value,
                 Load(sum_address->getAllocatedType(), sum_address),
                 convolution->shape().element_type());
  Store(sum, sum_address);

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return FPCast(Load(sum_address->getAllocatedType(), sum_address),
                lhs_llvm_type);
}

// Evaluate a polynomial using Horner's method.
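// Coefficients are ordered from the highest-degree term down to the constant
// term, so for {c0, c1, c2} this computes (c0 * x + c1) * x + c2.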
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
    llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
  llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
  for (const double c : coefficients) {
    poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
  }
  return poly;
}

}  // namespace xla