/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {

StatusOr<llvm::Value*> EmitReducePrecisionIR(
    PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
    int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) {
  using llvm::APInt;

  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }

  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64_t nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);

  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  // Clear the sign bit; it does not participate in rounding, and we will
  // restore it later.
  APInt sign_bit_mask(nbits, 1);
  sign_bit_mask <<= nbits - 1;
  llvm::Value* x_abs_bits =
      b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, ~sign_bit_mask));

  APInt exp_bits_mask(nbits, 1);
  exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                  << src_mantissa_bits;
  auto x_is_nan = b->CreateICmpUGT(
      x_abs_bits, llvm::ConstantInt::get(int_type, exp_bits_mask));

  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;

    // Compute rounding bias for round-to-nearest with ties to even.  This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits.  Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    llvm::Value* x_rounded = b->CreateAdd(x_as_int, x_rounding_bias);
    x_rounded = b->CreateAnd(x_rounded,
                             llvm::ConstantInt::get(int_type, truncation_mask));
    if (quiet_nans) {
      x_as_int = b->CreateSelect(x_is_nan, x_as_int, x_rounded);
    } else {
      x_as_int = x_rounded;
    }
  }

  if (dest_exponent_bits < src_exponent_bits) {
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that dest_exponent_bits >= 1.
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;

    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;

    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));

    // Force to zero or infinity if overflow or underflow.  (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }
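  // Worked example of the clamp above (illustrative numbers, not from this
  // file): narrowing f32 (src_exponent_bits = 8, bias 127) to
  // dest_exponent_bits = 5 gives reduced_exponent_bias = 2^4 - 1 = 15, so
  // biased exponents above 127 + 15 = 142 saturate to infinity and those at
  // or below 127 - 15 = 112 flush to zero.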

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory).  This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.

  if (dest_mantissa_bits > 0) {
    if (quiet_nans) {
      APInt qnan_mask(nbits, 1);
      qnan_mask <<= src_mantissa_bits - 1;
      llvm::Value* x_with_qnan_bit_set =
          b->CreateOr(x_as_int, llvm::ConstantInt::get(int_type, qnan_mask));
      x_with_qnan_bit_set = b->CreateBitCast(x_with_qnan_bit_set, float_type);
      result = b->CreateSelect(x_is_nan, x_with_qnan_bit_set, result);
    } else {
      result = b->CreateSelect(x_is_nan, x, result);
    }
  } else {
    result = b->CreateSelect(x_is_nan,
                             llvm::ConstantFP::getInfinity(float_type), result);
  }

  return result;
}
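
// A minimal host-side sketch of the mantissa-rounding trick above, assuming
// an f32 input (23 explicit mantissa bits) and ignoring the exponent-clamp
// and NaN-quieting steps; the helper name is illustrative and not part of
// this file. Extracted on its own it needs <cstdint> and <cstring>:
//
//   float RoundMantissaSketch(float x, int dest_mantissa_bits) {
//     uint32_t bits;
//     std::memcpy(&bits, &x, sizeof bits);  // the bit-cast above
//     const int shift = 23 - dest_mantissa_bits;
//     const uint32_t last_kept_bit = (bits >> shift) & 1u;
//     // Round-to-nearest, ties to even: base bias 0111... plus the last
//     // remaining mantissa bit; a carry into the exponent is benign.
//     const uint32_t bias = ((1u << shift) >> 1) - 1 + last_kept_bit;
//     bits = (bits + bias) & ~((1u << shift) - 1);
//     float out;
//     std::memcpy(&out, &bits, sizeof out);
//     return out;
//   }
//
// E.g. RoundMantissaSketch(1.0f + 0x1p-23f, 7) returns exactly 1.0f, since a
// 7-bit mantissa (bf16) cannot represent the extra low-order bit.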

StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
                                     llvm::IRBuilder<>* b) {
  TF_ASSIGN_OR_RETURN(
      auto reduced_precision,
      EmitReducePrecisionIR(
          /*src_ty=*/F32, f32_value,
          /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
          /*quiet_nans=*/true, b));
  auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
  auto shifted = b->CreateLShr(as_int32, 16);
  auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
  return b->CreateBitCast(truncated, b->getInt16Ty());
}

llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
  auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
  auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
  auto shifted = b->CreateShl(as_int32, 16);
  return b->CreateBitCast(shifted, b->getFloatTy());
}
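
// Host-side mirror of the two conversions above (illustrative helpers, not
// part of this file; bf16 here is just the high half of an f32). Extracted,
// they need <cstdint> and <cstring>:
//
//   uint16_t F32ToBF16BitsSketch(float f) {  // truncation only; rounding is
//     uint32_t u;                            // done by EmitReducePrecisionIR
//     std::memcpy(&u, &f, sizeof u);
//     return static_cast<uint16_t>(u >> 16);
//   }
//   float BF16BitsToF32Sketch(uint16_t b) {
//     uint32_t u = static_cast<uint32_t>(b) << 16;
//     float f;
//     std::memcpy(&f, &u, sizeof f);
//     return f;
//   }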

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return b->CreateSIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return b->CreateUIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
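
// The kSign lowering above is the branch-free idiom below in host C++ (the
// helper name is illustrative; `>>` on a negative signed value is an
// arithmetic shift on the platforms XLA targets):
//
//   int32_t SignSketch(int32_t x) {
//     return x == 0 ? 0 : ((x >> 31) | 1);
//   }
//
// For x < 0 the shift yields -1 (all ones) and -1 | 1 == -1; for x > 0 it
// yields 0 and 0 | 1 == 1.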

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      auto* to_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module_);
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value, to_ir_type);
      }
      auto* from_ir_type = llvm_ir::PrimitiveTypeToIrType(from_type, module_);
      int to_width = primitive_util::BitWidth(to_type);
      if (primitive_util::IsSignedIntegralType(to_type)) {
        int64_t min_int = llvm::minIntN(to_width);
        int64_t max_int = llvm::maxIntN(to_width);
        auto zero_int = llvm::ConstantInt::get(to_ir_type, 0);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToSI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // x <= static_cast<float>(INT_MIN) ? INT_MIN : ...
        clamped = Select(FCmpOLE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(INT_MAX) ? INT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        // isnan(x) ? 0 : ...
        clamped =
            Select(FCmpUNO(operand_value, operand_value), zero_int, clamped);
        return clamped;
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        uint64_t min_int = 0;
        uint64_t max_int = llvm::maxUIntN(to_width);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToUI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // (x <= 0.0 || isnan(x)) ? 0 : ...
        clamped = Select(FCmpULE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(UINT_MAX) ? UINT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        return clamped;
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value, "");
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    // TODO(b/238238423): llvm::Intrinsic::nearbyint is equivalent to roundeven
    // as TF and JAX default to FE_TONEAREST. Call llvm::Intrinsic::roundeven
    // instead once GPU emitter supports lowering LLVM.
    case HloOpcode::kRoundNearestEven:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nearbyint,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf (an ordered comparison): this works because ordered
      // comparisons return false if either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
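
// A scalar sketch of the saturating float->signed-int conversion emitted for
// kConvert above, specialized to f32->s32 (illustrative helper; extracted it
// needs <cstdint> and <cmath>):
//
//   int32_t SaturatingF32ToS32Sketch(float x) {
//     if (std::isnan(x)) return 0;  // isnan(x) ? 0 : ...
//     if (x <= static_cast<float>(INT32_MIN)) return INT32_MIN;
//     if (x >= static_cast<float>(INT32_MAX)) return INT32_MAX;
//     return static_cast<int32_t>(x);  // in-range: plain fptosi
//   }
//
// The clamps run before the raw conversion is trusted, because fptosi on an
// out-of-range input produces a poison value in LLVM.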

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      return EmitComplexLog(op, operand_value);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
      // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle,
                          EmitAtan2(component_type, b, a_plus_one, ""));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(-2a))])
               / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
             =(e^(2a)-e^(-2a) +
               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
              ([cos(b)^2 + sin(b)^2]*[e^(2a)+e^(-2a)] + 2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
              (e^(2a)+e^(-2a)+2*[cos(2b)])
             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
      */
      llvm::Value* a = EmitExtractReal(operand_value);
      llvm::Value* b = EmitExtractImag(operand_value);

      llvm::Type* type = a->getType();

      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
      llvm::Value* two_a = FAdd(a, a);
      llvm::Value* neg_2a = FMul(neg_one, two_a);

      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
      // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
      // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
      // ULP to be arbitrarily small. For larger values of `a`, calculating the
      // numerator as Exp(2a)-Exp(-2a) vs. Expm1(2a)-Expm1(-2a) returns
      // virtually identical results.
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
                          EmitExpm1(component_type, two_a));
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
                          EmitExpm1(component_type, neg_2a));
      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);

      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
      // = 2cos(b)^2. This gives us the ability to be more precise when the
      // denominator is close to zero.
      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);

      // Similarly we can compute sin(2b) with the formula sin(2b) =
      // 2*sin(b)*cos(b).
      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));

      // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
      // for small values of x. As a result, due to floating point precision
      // issues, x^2 is a better approximation than Expm1(x) + Expm1(-x) for
      // small values of x.
      llvm::Value* a_sqr = FMul(a, a);
      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);

      llvm::Value* exp_sum_m2 =
          Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);

      // As `a` grows toward +inf and -inf, the real numerator will grow towards
      // +inf and -inf respectively, while the denominator will always grow
      // towards +inf. The result is real_numerator/denom = NaN, when it should
      // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
      // we just hardcode the limits for the real numbers.
      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);

      llvm::Value* real =
          Select(is_inf, real_limit, FDiv(real_numerator, denom));
      llvm::Value* imag = FDiv(imag_numerator, denom);

      // The complex tanh functions have a few corner cases:
      // 1. (+0, +0) => (+0, +0)        - Handled normally
      // 2. (x, +Inf) => (NaN, NaN)     - See below
      // 3. (x, NaN) => (NaN, NaN)      - See below
      // 4. (+inf, y) => (1, +0)        - Handled normally
      // 5. (+Inf, +Inf) => (1, +/-0)   - See below
      // 6. (+Inf, NaN) => (1, +/-0)    - See below
      // 7. (NaN, +0) => (NaN, +0)      - See below
      // 8. (NaN, y) => (NaN, NaN)      - Handled normally
      // 9. (NaN, NaN) => (NaN, NaN)    - Handled normally
      //
      // For the cases that aren't handled normally:
      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
      //      then we return (+/-1, +/-0). However, this is only true if we
      //      assume that a is infinity or b is finite. In the event that both a
      //      is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to (NaN,
      //      NaN).
      // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
      //      When the denominator is infinity, this assures us that the zero is
      //      the correct sign. However if our imaginary input results in
      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
      // 7)   In the event that a is NaN, the denominator will be NaN.
      //      Therefore, the normal calculation gives (NaN, NaN) while we need
      //      (NaN, +0).
      if (!(b_->getFastMathFlags().noNaNs() &&
            b_->getFastMathFlags().noInfs())) {
        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                          {a}, {type}, b_);
        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
        llvm::Value* nan = llvm::ConstantFP::getNaN(type);

        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
        llvm::Value* b_is_zero = FCmpOEQ(b, zero);

        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
        // imag_numerator is NaN.
        llvm::Value* sin_2b_is_nan =
            b_->CreateFCmpUNO(imag_numerator, imag_numerator);

        llvm::Value* real_is_nan =
            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
        llvm::Value* imag_is_zero =
            b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));

        real = Select(real_is_nan, nan, real);
        imag = Select(imag_is_zero, zero, imag);
      }

      return EmitComposeComplex(op, real, imag);
    }
    case HloOpcode::kAbs: {
      return EmitComplexAbs(component_type, operand_value);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      TF_ASSIGN_OR_RETURN(auto cplx_abs,
                          EmitComplexAbs(component_type, operand_value));
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      return EmitComplexSqrt(op, component_type, operand_value);
    }
    case HloOpcode::kRsqrt: {
      return EmitComplexRsqrt(op, component_type, operand_value);
    }
    case HloOpcode::kCbrt: {
      return EmitComplexCbrt(op, component_type, operand_value);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
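
// A scalar mirror of the kTanh numerics above, without the corner-case
// selects (illustrative sketch; extracted it needs <cmath> and <complex>):
//
//   std::complex<double> TanhSketch(std::complex<double> z) {
//     double a = z.real(), b = z.imag();
//     double e2am1 = std::expm1(2 * a), en2am1 = std::expm1(-2 * a);
//     double real_num = e2am1 - en2am1;                  // e^(2a) - e^(-2a)
//     double imag_num = 4 * std::sin(b) * std::cos(b);   // 2*sin(2b)
//     // denom = e^(2a) + e^(-2a) + 2*cos(2b), using cos(2b) + 1 = 2*cos(b)^2.
//     double denom = (e2am1 + en2am1) + 4 * std::cos(b) * std::cos(b);
//     return {real_num / denom, imag_num / denom};
//   }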

StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (operand_type == PRED) {
    return EmitPredBinaryOp(op, lhs_value, rhs_value);
  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape())) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value, op->name());
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value, op->name());
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value, op->name());
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value, op->name());
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value, op->name());
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison.  This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      PrimitiveType operand_type = op->operand(0)->shape().element_type();
      if (operand_type == BF16) {
        lhs_value = EmitBF16ToF32(lhs_value, b_);
        rhs_value = EmitBF16ToF32(rhs_value, b_);
      }
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_, op->name());
      }
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value, op->name());
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value, op->name());
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
                     op->name());
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
                       op->name());
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
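
// The ordered/unordered split above matches C++'s own IEEE-754 semantics
// (host illustration; needs <cmath>):
//
//   double nan = std::nan("");
//   bool eq = (nan == nan);  // false: == behaves like FCMP_OEQ
//   bool ne = (nan != nan);  // true:  != behaves like FCMP_UNE, so
//                            // x != y is always !(x == y)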

// Computing sqrt(a^2 + b^2) directly can cause overflow errors. Assuming that
// |a| >= |b|, we can instead use
// sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
//                 = |a| * sqrt(1 + (b/a)^2)
//
// This method returns the min, max, and sqrt term for this calculation. This is
// done to prevent potential overflow errors that can occur from multiplying the
// max with the sqrt term. (I.e., when calculating the sqrt of the absolute
// value, we can take the sqrt of the max and the sqrt term before multiplying
// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
// sqrt(1 + (b/a)^2).
StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
                                         llvm::Value* operand_value,
                                         bool return_sqrt) {
  llvm::Value* real = EmitExtractReal(operand_value);
  llvm::Value* imag = EmitExtractImag(operand_value);
  llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
  llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
  llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
  llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");

  llvm::Value* div = FDiv(min, max);
  llvm::Value* div_sq = FMul(div, div);
  llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
  return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
}
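
// A scalar sketch of the overflow-safe |a+bi| this helper enables
// (illustrative; extracted it needs <algorithm> and <cmath>):
//
//   double ComplexAbsSketch(double re, double im) {
//     double mx = std::max(std::fabs(re), std::fabs(im));
//     double mn = std::min(std::fabs(re), std::fabs(im));
//     if (mx == 0) return 0;             // the NaN-select in EmitComplexAbs
//     double r = mn / mx;                // handles this case in the IR
//     return mx * std::sqrt(1 + r * r);  // |max| * sqrt(1 + (min/max)^2)
//   }
//
// E.g. with re = im = 1e300, re*re alone would overflow to inf, but here
// mx = 1e300 and sqrt(1 + 1) stays finite, giving about 1.414e300.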

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  llvm::Value* result = FMul(max, sqrt);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
// sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* one_p_div_sq;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, one_p_div_sq),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * pow,
                      EmitPow(prim_type, one_p_div_sq,
                              llvm::ConstantFP::get(max->getType(), .25), ""));
  llvm::Value* result = FMul(sqrt_max, pow);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
// rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
  llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return rsqrt(min) instead of `result`.
  return Select(FCmpUNO(result, result), rsqrt_min, result);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAdd(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSubtract(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexMultiply(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op,
      FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
      FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexDivide(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Division of complex numbers is implemented here, taking into account
  // over/underflow, NaN and Inf values.
  auto a_r = EmitExtractReal(lhs_value);
  auto a_i = EmitExtractImag(lhs_value);
  auto b_r = EmitExtractReal(rhs_value);
  auto b_i = EmitExtractImag(rhs_value);
  auto type = a_r->getType();

  // Smith's algorithm to divide complex numbers. It is just a bit smarter
  // way to compute the following formula:
  //  (a_r + a_i * i) / (b_r + b_i * i)
  //    = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
  //    = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
  //
  // Depending on whether |b_r| < |b_i| we compute either
  //   b_r_b_i_ratio = b_r / b_i
  //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
  //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
  //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
  //
  // or
  //
  //   b_i_b_r_ratio = b_i / b_r
  //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
  //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
  //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
  //
  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
  auto b_r_b_i_ratio = FDiv(b_r, b_i);
  auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
  auto b_i_b_r_ratio = FDiv(b_i, b_r);
  auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));

  auto b_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_r}, {type}, b_);
  auto b_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_i}, {type}, b_);
  auto b_r_lt_b_i = FCmpOLT(b_r_abs, b_i_abs);
  auto c_r = Select(b_r_lt_b_i,
                    FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
                    FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
  auto c_i = Select(b_r_lt_b_i,
                    FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
                    FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
  auto result = EmitComposeComplex(op, c_r, c_i);

  // Consider corner cases, if the result is (NaN, NaN).
  auto zero = llvm::ConstantFP::get(type, 0.0);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto inf = llvm::ConstantFP::getInfinity(type);

  // Case 1. Zero denominator.
  auto zero_denominator =
      And(And(FCmpOEQ(b_r_abs, zero), FCmpOEQ(b_i_abs, zero)),
          Or(Not(FCmpUNO(a_r, zero)), Not(FCmpUNO(a_i, zero))));
  auto inf_with_sign_of_b_r = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {inf, b_r}, {type}, b_);
  auto zero_denominator_result = EmitComposeComplex(
      op, FMul(inf_with_sign_of_b_r, a_r), FMul(inf_with_sign_of_b_r, a_i));

  // Case 2. Infinite numerator, finite denominator.
  auto b_r_finite = FCmpONE(b_r_abs, inf);
  auto b_i_finite = FCmpONE(b_i_abs, inf);
  auto a_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_r}, {type}, b_);
  auto a_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_i}, {type}, b_);
  auto a_r_infinite = FCmpOEQ(a_r_abs, inf);
  auto a_i_infinite = FCmpOEQ(a_i_abs, inf);
  auto inf_num_finite_denom =
      And(Or(a_r_infinite, a_i_infinite), And(b_r_finite, b_i_finite));

  auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_r_infinite, one, zero), a_r}, {type},
      b_);
  auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_i_infinite, one, zero), a_i}, {type},
      b_);
  auto inf_num_finite_denom_result = EmitComposeComplex(
      op,
      FMul(inf,
           FAdd(FMul(a_r_inf_with_sign, b_r), FMul(a_i_inf_with_sign, b_i))),
      FMul(inf,
           FSub(FMul(a_i_inf_with_sign, b_r), FMul(a_r_inf_with_sign, b_i))));

  // Case 3. Finite numerator, infinite denominator.
  auto a_r_finite = FCmpONE(a_r_abs, inf);
  auto a_i_finite = FCmpONE(a_i_abs, inf);
  auto b_r_infinite = FCmpOEQ(b_r_abs, inf);
  auto b_i_infinite = FCmpOEQ(b_i_abs, inf);
  auto finite_num_inf_denom =
      And(Or(b_r_infinite, b_i_infinite), And(a_r_finite, a_i_finite));

  auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_r_infinite, one, zero), b_r}, {type},
      b_);
  auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_i_infinite, one, zero), b_i}, {type},
      b_);
  auto finite_num_inf_denom_result = EmitComposeComplex(
      op,
      FMul(zero,
           FAdd(FMul(a_r, b_r_inf_with_sign), FMul(a_i, b_i_inf_with_sign))),
      FMul(zero,
           FSub(FMul(a_i, b_r_inf_with_sign), FMul(a_r, b_i_inf_with_sign))));

  auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
  return Select(c_nan,
                Select(zero_denominator, zero_denominator_result,
                       Select(inf_num_finite_denom, inf_num_finite_denom_result,
                              Select(finite_num_inf_denom,
                                     finite_num_inf_denom_result, result))),
                result);
}
1160 
1161 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexLog(
1162     const HloInstruction* op, llvm::Value* operand_value) {
1163   // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
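  // Sanity check: log(0 + 1i) has abs = 1 and atan2(1, 0) = pi/2, giving
  // 0 + (pi/2)i, as expected for the principal branch.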
1164   PrimitiveType component_type =
1165       primitive_util::ComplexComponentType(op->shape().element_type());
1166   auto a = EmitExtractReal(operand_value);
1167   auto b = EmitExtractImag(operand_value);
1168   TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, ""));
1169   TF_ASSIGN_OR_RETURN(llvm::Value * abs,
1170                       EmitComplexAbs(component_type, operand_value));
1171   TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
1172   return EmitComposeComplex(op, log_abs, angle);
1173 }
1174 
1175 // Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
1176 //   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
1177 // = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
1178 // = r^0.5 * [cos(t/2) + i*sin(t/2)]
1179 // = sqrt(r) * [cos(t/2) + i*sin(t/2)]
1180 // where r = |a+bi| and t = atan2(b,a)
1181 // TODO(bixia): See doc for implementation without atan2.
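// Worked example: sqrt(-1 + 0i) gives r = sqrt(|-1 + 0i|) = 1 and
// t = atan2(0, -1) = pi, so the result is
// 1 * (cos(pi/2) + i*sin(pi/2)) = 0 + 1i, as expected.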
1182 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
1183     const HloInstruction* op, PrimitiveType prim_type,
1184     llvm::Value* operand_value) {
1185   llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
1186                          ->getElementType(0);
1187 
1188   TF_ASSIGN_OR_RETURN(llvm::Value * r,
1189                       EmitSqrtComplexAbs(prim_type, operand_value));
1190 
1191   llvm::Value* a = EmitExtractReal(operand_value);
1192   llvm::Value* b = EmitExtractImag(operand_value);
1193   TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
1194 
1195   llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
1196   llvm::Value* angle = FMul(t, c);
1197   TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
1198   TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
1199 
1200   llvm::Value* real_part;
1201   llvm::Value* imag_part;
1202 
1203   llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1204 
1205   if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
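    // The selects below handle the non-finite special cases, in the spirit of
    // C99 Annex G's csqrt: e.g. sqrt(x + inf*i) = inf + inf*i for any x, and
    // sqrt(-inf + yi) = 0 + copysign(inf, y)*i for finite y.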
1206     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1207     llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
1208     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1209     llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1210                                                       {b}, {b->getType()}, b_);
1211 
1212     real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
1213                        Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
1214                               zero, FMul(r, cos)));
1215 
1216     llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
1217         llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
1218     imag_part =
1219         Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
1220                Select(FCmpUNO(r, r), nan,
1221                       Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
1222   } else {
1223     real_part = FMul(r, cos);
1224     imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
1225   }
1226 
1227   return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
1228                 EmitComposeComplex(op, real_part, imag_part));
1229 }
1230 
1231 // Similar to Sqrt, we can use our EmitComplexPower formula, but set
1232 // c=-0.5 and d=0. We get:
1233 //   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
1234 // = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
1235 // = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
1236 // = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
1237 // where r = |a+bi| and t = atan2(b,a).
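// Sanity check: rsqrt(4 + 0i) gives r = 1/sqrt(4) = 0.5 and t = atan2(0, 4)
// = 0, so the result is 0.5 * (cos(0) + i*sin(0)) = 0.5 + 0i.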
1238 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
1239     const HloInstruction* op, PrimitiveType prim_type,
1240     llvm::Value* operand_value) {
1241   llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
1242                          ->getElementType(0);
1243 
1244   TF_ASSIGN_OR_RETURN(llvm::Value * r,
1245                       EmitRsqrtComplexAbs(prim_type, operand_value));
1246 
1247   llvm::Value* a = EmitExtractReal(operand_value);
1248   llvm::Value* b = EmitExtractImag(operand_value);
1249   TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
1250 
1251   llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
1252   llvm::Value* angle = FMul(t, c);
1253   TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
1254   TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
1255 
1256   llvm::Value* real_part = FMul(r, cos);
1257   llvm::Value* imag_part = FMul(r, sin);
1258 
1259   if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
1260     llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1261     llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
1262     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1263     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1264     // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
1265     llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
1266         llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
1267     llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
1268         llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
1269     llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);
1270 
1271     llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1272                                                       {a}, {a->getType()}, b_);
1273     llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1274                                                       {b}, {b->getType()}, b_);
1275 
1276     llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
1277     real_part = Select(
1278         is_zero_zero, inf,
1279         Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
1280                a_signed_zero, FMul(r, cos)));
1281     imag_part = Select(
1282         is_zero_zero, nan,
1283         Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
1284                neg_b_signed_zero, FMul(r, sin)));
1285   } else {
1286     llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1287     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1288     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1289 
1290     llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
1291     real_part = Select(is_zero_zero, inf, FMul(r, cos));
1292     imag_part = Select(is_zero_zero, nan, FMul(r, sin));
1293   }
1294 
1295   return EmitComposeComplex(op, real_part, imag_part);
1296 }
1297 
1298 //
1299 // Using EmitComplexPower with c=1.0/3.0 and d=0
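// Note that this yields the principal complex cube root, which for negative
// real inputs differs from the real-valued cbrt: with c = 1/3, d = 0 the
// formula gives cbrt(-8 + 0i) = 2*(cos(pi/3) + i*sin(pi/3)) = 1 + sqrt(3)*i
// rather than -2.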
1300 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
1301     const HloInstruction* op, PrimitiveType prim_type,
1302     llvm::Value* operand_value) {
1303   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1304   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1305   auto zero = llvm::ConstantFP::get(type, 0);
1306   llvm::Value* a = EmitExtractReal(operand_value);
1307   llvm::Value* b = EmitExtractImag(operand_value);
1308   return EmitComplexPower(op, a, b, third, zero);
1309 }
1310 
1311 // (a+bi)^(c+di) =
1312 //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1313 //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
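// Sanity check with a = 0, b = 1, c = 2, d = 0 (i.e. i^2): the coefficient is
// (0 + 1)^1 * exp(0) = 1 and q = 2 * atan2(1, 0) = pi, so the result is
// cos(pi) + i*sin(pi) = -1 + 0i, as expected.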
1314 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
1315     const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
1316     llvm::Value* d) {
1317   PrimitiveType component_type =
1318       primitive_util::ComplexComponentType(op->shape().element_type());
1319   auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
1320   auto zero = llvm::ConstantFP::get(a->getType(), 0);
1321   auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
1322   auto one = llvm::ConstantFP::get(a->getType(), 1);
1323   auto half_c = FMul(one_half, c);
1324 
1325   TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
1326                       EmitPow(component_type, aa_p_bb, half_c, ""));
1327 
1328   auto neg_d = FNeg(d);
1329   TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
1330   auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
1331   TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
1332                       EmitExp(component_type, neg_d_arg_lhs, ""));
1333   auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
1334   TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
1335   auto half_d = FMul(one_half, d);
1336   auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
1337   TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
1338   TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
1339   // 0^c is 0 if the exponent is real (d is 0) and c > 0, and 0^0 is
1340   // defined to be 1.0; see "Branch Cuts for Complex Elementary Functions,
1341   // or Much Ado About Nothing's Sign Bit", W. Kahan, Section 10.
1342   return Select(
1343       And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
1344       EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
1345       EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
1346 }
1347 
1348 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
1349     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1350   switch (op->opcode()) {
1351     case HloOpcode::kAdd:
1352       return EmitComplexAdd(op, lhs_value, rhs_value);
1353     case HloOpcode::kSubtract:
1354       return EmitComplexSubtract(op, lhs_value, rhs_value);
1355     case HloOpcode::kMultiply:
1356       return EmitComplexMultiply(op, lhs_value, rhs_value);
1357     case HloOpcode::kDivide: {
1358       return EmitComplexDivide(op, lhs_value, rhs_value);
1359     }
1360     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
1361     // comparisons always return false when one of the operands is NaN, whereas
1362     // unordered comparisons return true.
1363     //
1364     // We use ordered comparisons for everything except kNe, where we use an
1365     // unordered comparison.  This makes x != y equivalent to !(x == y), and
1366     // matches C++'s semantics.
1367     case HloOpcode::kCompare: {
1368       switch (op->comparison_direction()) {
1369         case ComparisonDirection::kEq:
1370           return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1371                                              EmitExtractReal(lhs_value),
1372                                              EmitExtractReal(rhs_value), b_),
1373                      llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1374                                              EmitExtractImag(lhs_value),
1375                                              EmitExtractImag(rhs_value), b_));
1376         case ComparisonDirection::kNe:
1377           return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1378                                             EmitExtractReal(lhs_value),
1379                                             EmitExtractReal(rhs_value), b_),
1380                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1381                                             EmitExtractImag(lhs_value),
1382                                             EmitExtractImag(rhs_value), b_));
1383         default:
1384           return Unimplemented(
1385               "complex comparison '%s'",
1386               ComparisonDirectionToString(op->comparison_direction()));
1387       }
1388     }
1389     case HloOpcode::kPower: {
1390       auto a = EmitExtractReal(lhs_value);
1391       auto b = EmitExtractImag(lhs_value);
1392       auto c = EmitExtractReal(rhs_value);
1393       auto d = EmitExtractImag(rhs_value);
1394       return EmitComplexPower(op, a, b, c, d);
1395     }
1396     case HloOpcode::kAtan2: {
1397       // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
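      // Sanity check for real inputs y = x = 1: (1 + 1i)/sqrt(2) = e^(i*pi/4),
      // whose log is i*pi/4, and -i * (i*pi/4) = pi/4 = atan2(1, 1).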
1398       auto y = lhs_value;
1399       auto x = rhs_value;
1400       TF_ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x));
1401       TF_ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y));
1402       TF_ASSIGN_OR_RETURN(auto x_squared_plus_y_squared,
1403                           EmitComplexAdd(op, x_squared, y_squared));
1404       auto component_type =
1405           primitive_util::ComplexComponentType(op->shape().element_type());
1406       TF_ASSIGN_OR_RETURN(
1407           auto sqrt_x_squared_plus_y_squared,
1408           EmitComplexSqrt(op, component_type, x_squared_plus_y_squared));
1409       auto type = llvm_ir::PrimitiveTypeToIrType(component_type, module_);
1410       auto zero = llvm::ConstantFP::get(type, 0.0);
1411       auto one = llvm::ConstantFP::get(type, 1.0);
1412       auto i = EmitComposeComplex(op, zero, one);
1413       TF_ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y));
1414       TF_ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y));
1415       TF_ASSIGN_OR_RETURN(
1416           auto div_result,
1417           EmitComplexDivide(op, x_plus_iy, sqrt_x_squared_plus_y_squared));
1418       TF_ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result));
1419       auto negative_one = llvm::ConstantFP::get(type, -1.0);
1420       auto negative_i = EmitComposeComplex(op, zero, negative_one);
1421       return EmitComplexMultiply(op, negative_i, log_result);
1422     }
1423     default:
1424       return Unimplemented("binary complex op '%s'",
1425                            HloOpcodeString(op->opcode()));
1426   }
1427 }
1428 
1429 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
1430                                               llvm::Value* rhs_value,
1431                                               absl::string_view name) {
1432   return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
1433 }
1434 
1435 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
1436                                               llvm::Value* rhs_value,
1437                                               absl::string_view name) {
1438   return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
1439 }
1440 
1441 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
1442                                                    llvm::Value* value) {
1443   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
1444                                       {value->getType()}, b_);
1445 }
1446 
1447 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
1448                                                      llvm::Value* value) {
1449   auto x = value;
1450   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1451   auto one = llvm::ConstantFP::get(type, 1.0);
1452   auto negative_half = llvm::ConstantFP::get(type, -0.5);
1453   // When x is large, the naive evaluation of ln(x + 1) is more
1454   // accurate than the Taylor series.
1455   TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
1456   // When |x| is small (defined to be less than sqrt(2) - 1), use a rational
1457   // approximation. The approximation below is based on one from the Cephes
1458   // Mathematical Library.
1459   //
1460   // sqrt(2) - 1.
1461   const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
1462 
1463   static const std::array<double, 7> kDenominatorCoeffs{
1464       1.,
1465       1.5062909083469192043167E1,
1466       8.3047565967967209469434E1,
1467       2.2176239823732856465394E2,
1468       3.0909872225312059774938E2,
1469       2.1642788614495947685003E2,
1470       6.0118660497603843919306E1,
1471   };
1472 
1473   static const std::array<double, 7> kNumeratorCoeffs{
1474       4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
1475       6.5787325942061044846969E0,  2.9911919328553073277375E1,
1476       6.0949667980987787057556E1,  5.7112963590585538103336E1,
1477       2.0039553499201281259648E1,
1478   };
1479 
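  // The small-x branch below evaluates the Cephes-style form
  //   log1p(x) ~= x - x^2/2 + x^3 * P(x) / Q(x),
  // where P and Q are the numerator/denominator polynomials above.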
1480   auto x_squared = FMul(x, x);
1481   TF_ASSIGN_OR_RETURN(auto denominator,
1482                       EvaluatePolynomial(type, x, kDenominatorCoeffs));
1483   TF_ASSIGN_OR_RETURN(auto numerator,
1484                       EvaluatePolynomial(type, x, kNumeratorCoeffs));
1485   auto for_small_x = FDiv(numerator, denominator);
1486   for_small_x = FMul(FMul(x, x_squared), for_small_x);
1487   for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
1488   for_small_x = FAdd(x, for_small_x);
1489 
1490   auto abs_x =
1491       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1492   auto x_is_small = FCmpOLT(
1493       abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
1494   return Select(x_is_small, for_small_x, for_large_x);
1495 }
1496 
1497 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
1498                                                     llvm::Value* value) {
1499   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
1500                                       {value->getType()}, b_);
1501 }
1502 
1503 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
1504                                                      llvm::Value* value) {
1505   TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
1506   return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
1507 }
1508 
1509 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
1510                                                    llvm::Value* value) {
1511   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
1512                                       {value->getType()}, b_);
1513 }
1514 
1515 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
1516                                                    llvm::Value* value) {
1517   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
1518                                       {value->getType()}, b_);
1519 }
1520 
1521 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
1522                                                    llvm::Value* value,
1523                                                    absl::string_view name) {
1524   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
1525                                       {value->getType()}, b_, name);
1526 }
1527 
1528 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
1529                                                      llvm::Value* value) {
1530   auto x = value;
1531   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1532   auto one = llvm::ConstantFP::get(type, 1.0);
1533   auto half = llvm::ConstantFP::get(type, 0.5);
1534   auto zero = llvm::ConstantFP::get(type, 0.0);
1535 
1536   // expm1(x) == tanh(x/2) * (exp(x) + 1)
1537   // x/2 can underflow; if it does, we approximate expm1(x) with x.
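  // The identity holds because tanh(x/2) = (exp(x) - 1) / (exp(x) + 1), so
  // multiplying by (exp(x) + 1) recovers exp(x) - 1 while avoiding the
  // catastrophic cancellation of computing exp(x) - 1 directly near x = 0.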
1538   auto x_over_two = FMul(x, half);
1539   auto x_over_two_is_zero = FCmpOEQ(x_over_two, zero);
1540   auto abs_x =
1541       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
1542   // Use the naive exp(x) - 1 calculation if |x| > 0.5.
1543   auto x_magnitude_is_large = FCmpOGT(abs_x, half);
1544   TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two));
1545   TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, ""));
1546   auto exp_of_x_plus_one = FAdd(exp_of_x, one);
1547   auto exp_of_x_minus_one = FSub(exp_of_x, one);
1548   auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one);
1549   expm1_of_x = Select(x_magnitude_is_large, exp_of_x_minus_one, expm1_of_x);
1550   expm1_of_x = Select(x_over_two_is_zero, x, expm1_of_x);
1551   return expm1_of_x;
1552 }
1553 
1554 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
1555                                                    llvm::Value* lhs,
1556                                                    llvm::Value* rhs,
1557                                                    absl::string_view name) {
1558   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
1559                                       {lhs->getType()}, b_, name);
1560 }
1561 
1562 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
1563                                                     llvm::Value* value) {
1564   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1565   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1566   auto abs_value =
1567       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1568   TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
1569                       EmitPow(prim_type, abs_value, third, ""));
1570   auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
1571                                                  {abs_res, value}, {type}, b_);
1572   return signed_res;
1573 }
1574 
1575 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
1576     PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
1577     absl::string_view /*name*/) {
1578   return Unimplemented("atan2");
1579 }
1580 
1581 StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
1582                                                     llvm::Value* value) {
1583   return Unimplemented("tanh");
1584 }
1585 
1586 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
1587     const HloInstruction* hlo, llvm::Value* x) {
1588   return EmitReducePrecisionIR(
1589       /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
1590       /*dest_exponent_bits=*/hlo->exponent_bits(),
1591       /*dest_mantissa_bits=*/hlo->mantissa_bits(),
1592       /*quiet_nans=*/false, b_);
1593 }
1594 
1595 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
1596                                              llvm::Value* lhs, llvm::Value* rhs,
1597                                              llvm::Value* shift_result,
1598                                              bool saturate_to_sign_bit) {
1599   llvm::IntegerType* integer_type =
1600       llvm::cast<llvm::IntegerType>(lhs->getType());
1601   unsigned integer_bitsize = integer_type->getBitWidth();
1602   llvm::ConstantInt* integer_bitsize_constant =
1603       llvm::ConstantInt::get(integer_type, integer_bitsize);
1604   llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
1605   llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
1606   llvm::Value* saturated_value;
1607   if (saturate_to_sign_bit) {
1608     saturated_value =
1609         b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
1610   } else {
1611     saturated_value = zero;
1612   }
1613   llvm::Value* shift_amt_in_range =
1614       b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
1615   return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
1616 }
1617 
1618 llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
1619   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
1620 }
1621 
1622 llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
1623   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
1624 }
1625 
1626 llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
1627   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1628   return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
1629                                                   integer_type->getBitWidth()));
1630 }
1631 
1632 llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
1633   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1634   return llvm::ConstantInt::get(
1635       integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
1636 }
1637 
1638 llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
1639   return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
1640 }
1641 
1642 llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
1643                                                           llvm::Value* rhs) {
1644   return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
1645              ICmpEQ(rhs, GetMinusOne(rhs->getType())));
1646 }
1647 
1648 llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
1649                                                    llvm::Value* rhs,
1650                                                    bool is_signed) {
1651   // Integer division overflow behavior:
1652   //
1653   // X / 0 == -1
1654   // INT_SMIN /s -1 = INT_SMIN
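  //
  // In both cases we swap the divisor for 1 before emitting the hardware
  // divide, so the division itself can never trap, and then use selects to
  // patch in the defined result.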
1655 
1656   if (!is_signed) {
1657     llvm::Value* udiv_is_unsafe = IsZero(rhs);
1658     llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1659     llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1660     return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1661   }
1662 
1663   llvm::Value* has_zero_divisor = IsZero(rhs);
1664   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1665   llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1666   llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1667   llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1668 
1669   return Select(
1670       has_zero_divisor, GetMinusOne(lhs->getType()),
1671       Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1672 }
1673 
1674 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1675                                                       llvm::Value* rhs,
1676                                                       bool is_signed) {
1677   // Integer remainder overflow behavior:
1678   //
1679   // X % 0 == X
1680   // INT_SMIN %s -1 = 0
1681 
1682   if (!is_signed) {
1683     llvm::Value* urem_is_unsafe = IsZero(rhs);
1684     llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1685     llvm::Value* safe_rem = URem(lhs, safe_rhs);
1686     return Select(urem_is_unsafe, lhs, safe_rem);
1687   }
1688 
1689   llvm::Value* has_zero_divisor = IsZero(rhs);
1690   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1691   llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1692   llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1693   llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1694 
1695   return Select(
1696       has_zero_divisor, lhs,
1697       Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1698 }
1699 
1700 llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
1701                                                 llvm::Value* exponent,
1702                                                 bool is_signed) {
1703   // Exponentiation by squaring:
1704   // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
1705   int bits = 6;  // Everything else would overflow for any base > 1, as 2^64
1706                  // already exceeds the largest possible 64-bit integer, and
1707                  // 64 == 1 << 6.
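  // For example, with exponent = 6 (0b110): iteration 0 leaves the
  // accumulator at 1, iteration 1 multiplies in base^2, and iteration 2
  // multiplies in base^4, so the accumulator ends at base^6.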
1708   llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
1709   llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
1710   llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
1711   llvm::Value* original_base = base;
1712   llvm::Value* original_exponent = exponent;
1713 
1714   // Unroll the loop at compile time.
1715   for (int i = 0; i < bits; i++) {
1716     accumulator =
1717         b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
1718                          b_->CreateMul(accumulator, base), accumulator);
1719     base = b_->CreateMul(base, base);
1720     exponent = b_->CreateLShr(exponent, 1);
1721   }
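  // If the (signed) exponent is negative, the result is defined to be 0
  // unless the base is 1, in which case it is 1 (1^n == 1 for any n).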
1722   return b_->CreateSelect(
1723       b_->CreateICmpSGE(original_exponent, zero), accumulator,
1724       b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
1725 }
1726 
1727 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPredBinaryOp(
1728     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1729   // Per the reference interpreter, pred arithmetic should behave like
1730   // `int8_t(x) OP int8_t(y) != 0`.  For most permitted ops, we can just emit
1731   // the underlying i8 op to achieve this (e.g. kAnd, kOr, kXor, kMultiply).  In
1732   // the case of kAdd, we would need to insert a comparison instruction after
1733   // the addition, but it's both easier and faster to emit a bitwise or
1734   // instruction instead.
1735   //
1736   // For several of these ops, a faster bitwise implementation is available, but
1737   // LLVM is unlikely to be able to see it, since it gets IR that e.g. loads i8s
1738   // from memory, multiplies them, and writes the result back, without any
1739   // indication that the inputs were assumed to be 0 or 1.  So, just in case,
1740   // help it out by choosing the faster instruction to begin with.
1741   switch (op->opcode()) {
1742     case HloOpcode::kCompare:
1743     case HloOpcode::kXor:
1744       return EmitIntegerBinaryOp(op, lhs_value, rhs_value, false);
1745 
1746     // zext(i1 x) + zext(i1 y) != 0 === or(x, y)
1747     // max(zext(i1 x), zext(i1 y)) != 0 === or(x, y)
1748     case HloOpcode::kAdd:
1749     case HloOpcode::kMaximum:
1750     case HloOpcode::kOr:
1751       return Or(lhs_value, rhs_value);
1752 
1753     // zext(i1 x) * zext(i1 y) != 0 === and(x, y)
1754     // min(zext(i1 x), zext(i1 y)) != 0 === and(x, y)
1755     case HloOpcode::kMultiply:
1756     case HloOpcode::kMinimum:
1757     case HloOpcode::kAnd:
1758       return And(lhs_value, rhs_value);
1759 
1760     // These opcodes are rejected by shape-inference for PRED elements; calling
1761     // them out here serves more as documentation than a necessary check.
1762     case HloOpcode::kDivide:
1763     case HloOpcode::kRemainder:
1764     case HloOpcode::kPower:
1765     case HloOpcode::kSubtract:
1766     case HloOpcode::kShiftLeft:
1767     case HloOpcode::kShiftRightArithmetic:
1768     case HloOpcode::kShiftRightLogical:
1769       return InternalError("Invalid binary op '%s' for pred",
1770                            HloOpcodeString(op->opcode()));
1771 
1772     default:
1773       return Unimplemented("binary pred op '%s'",
1774                            HloOpcodeString(op->opcode()));
1775   }
1776 }
1777 
1778 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1779     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1780     bool is_signed) {
1781   switch (op->opcode()) {
1782     // TODO(jingyue): add the "nsw" attribute for signed types.
1783     case HloOpcode::kAdd:
1784       return Add(lhs_value, rhs_value);
1785     case HloOpcode::kSubtract:
1786       return Sub(lhs_value, rhs_value);
1787     case HloOpcode::kMultiply:
1788       return Mul(lhs_value, rhs_value);
1789     case HloOpcode::kDivide:
1790       return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
1791     case HloOpcode::kRemainder:
1792       return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
1793     case HloOpcode::kCompare: {
1794       switch (op->comparison_direction()) {
1795         case ComparisonDirection::kEq:
1796           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1797                                          rhs_value, b_);
1798         case ComparisonDirection::kNe:
1799           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1800                                          rhs_value, b_);
1801         case ComparisonDirection::kLt:
1802           return llvm_ir::EmitComparison(
1803               is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1804               lhs_value, rhs_value, b_);
1805         case ComparisonDirection::kGt:
1806           return llvm_ir::EmitComparison(
1807               is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1808               lhs_value, rhs_value, b_);
1809         case ComparisonDirection::kLe:
1810           return llvm_ir::EmitComparison(
1811               is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1812               lhs_value, rhs_value, b_);
1813         case ComparisonDirection::kGe:
1814           return llvm_ir::EmitComparison(
1815               is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1816               lhs_value, rhs_value, b_);
1817       }
1818     }
1819     case HloOpcode::kMinimum:
1820       return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1821     case HloOpcode::kMaximum:
1822       return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1823     case HloOpcode::kAnd:
1824       return And(lhs_value, rhs_value);
1825     case HloOpcode::kOr:
1826       return Or(lhs_value, rhs_value);
1827     case HloOpcode::kPower:
1828       return EmitIntegerPow(lhs_value, rhs_value, is_signed);
1829     case HloOpcode::kXor:
1830       return Xor(lhs_value, rhs_value);
1831 
1832     // Shifting out bits >= the number of bits in the type being shifted
1833     // produces a poison value in LLVM which is basically "deferred undefined
1834     // behavior" -- doing something observable with such a value precipitates
1835     // UB.  We replace the poison value with a constant to avoid this deferred
1836     // UB.
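    // For example, for i32 operands (x << 35) is poison in LLVM IR;
    // SaturateShiftIfNecessary instead yields the saturated constant (0 here,
    // or 0/-1 by sign bit for arithmetic right shifts) whenever rhs >= 32.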
1837     case HloOpcode::kShiftRightArithmetic:
1838       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1839                                       AShr(lhs_value, rhs_value),
1840                                       /*saturate_to_sign_bit=*/true);
1841     case HloOpcode::kShiftLeft:
1842       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1843                                       Shl(lhs_value, rhs_value),
1844                                       /*saturate_to_sign_bit=*/false);
1845     case HloOpcode::kShiftRightLogical:
1846       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1847                                       LShr(lhs_value, rhs_value),
1848                                       /*saturate_to_sign_bit=*/false);
1849     default:
1850       return Unimplemented("binary integer op '%s'",
1851                            HloOpcodeString(op->opcode()));
1852   }
1853 }
1854 
1855 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1856                                                  llvm::Value* rhs_value,
1857                                                  bool is_signed) {
1858   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1859                                          : llvm::ICmpInst::ICMP_UGE,
1860                                lhs_value, rhs_value),
1861                 lhs_value, rhs_value);
1862 }
1863 
1864 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1865                                                  llvm::Value* rhs_value,
1866                                                  bool is_signed) {
1867   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1868                                          : llvm::ICmpInst::ICMP_ULE,
1869                                lhs_value, rhs_value),
1870                 lhs_value, rhs_value);
1871 }
1872 
1873 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1874     const HloInstruction* hlo,
1875     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1876     const llvm_ir::IrArray::Index& index) {
1877   TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1878                       operand_to_generator.at(hlo->operand(0))(index));
1879   TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1880                       operand_to_generator.at(hlo->operand(1))(index));
1881   TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1882                       operand_to_generator.at(hlo->operand(2))(index));
1883   return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1884                 on_false_value);
1885 }
1886 
1887 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1888     const HloInstruction* hlo,
1889     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1890     const llvm_ir::IrArray::Index& index) {
1891   TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1892                       operand_to_generator.at(hlo->operand(0))(index));
1893   TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1894                       operand_to_generator.at(hlo->operand(1))(index));
1895   TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1896                       operand_to_generator.at(hlo->operand(2))(index));
1897   PrimitiveType prim_type = hlo->shape().element_type();
1898   if (primitive_util::IsFloatingPointType(prim_type)) {
1899     return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
1900   } else if (primitive_util::IsIntegralType(prim_type)) {
1901     bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1902     return EmitIntegralMin(
1903         max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1904   } else {
1905     return Unimplemented("Clamp unimplemented for %s",
1906                          PrimitiveType_Name(prim_type));
1907   }
1908 }
1909 
1910 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
1911     const HloInstruction* hlo,
1912     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1913     const llvm_ir::IrArray::Index& source_index) {
1914   const int64_t concat_dim = hlo->dimensions(0);
1915   llvm::BasicBlock* init_block = b_->GetInsertBlock();
1916 
1917   llvm::BasicBlock* exit_block;
1918   if (b_->GetInsertPoint() != init_block->end()) {
1919     // Inserting into the middle.
1920     CHECK(init_block->getTerminator());
1921     exit_block =
1922         init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
1923     init_block->getTerminator()->eraseFromParent();
1924   } else {
1925     // Inserting at the end.
1926     CHECK(!init_block->getTerminator());
1927     exit_block = llvm_ir::CreateBasicBlock(
1928         /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
1929   }
1930 
1931   llvm_ir::SetToFirstInsertPoint(exit_block, b_);
1932   llvm::PHINode* output = b_->CreatePHI(
1933       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
1934       hlo->operands().size());
1935   auto prior_insert_point = b_->GetInsertPoint();
1936 
1937   b_->SetInsertPoint(init_block);
1938 
1939   // Assign a unique id for each *different* operand, and count how often each
1940   // operand is used. If all operands are different, the usage count will be 1
1941   // for each operand.
1942   absl::flat_hash_map<const HloInstruction*, int64_t> to_unique_operand_id;
1943   std::vector<int64_t> operand_usage_count;
1944   for (const HloInstruction* operand : hlo->operands()) {
1945     if (to_unique_operand_id.contains(operand)) {
1946       ++operand_usage_count[to_unique_operand_id[operand]];
1947     } else {
1948       int64_t unique_operand_id = to_unique_operand_id.size();
1949       to_unique_operand_id[operand] = unique_operand_id;
1950       operand_usage_count.push_back(1);
1951     }
1952   }
1953 
1954   // To avoid emitting the same operand more than once, we create one basic
1955   // block for each *different* operand, with a PHI node for the different
1956   // source index inputs.
1957   std::vector<llvm::BasicBlock*> emit_operand_blocks(
1958       to_unique_operand_id.size(), nullptr);
1959   std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
1960                                                 nullptr);
1961   for (const HloInstruction* operand : hlo->operands()) {
1962     int64_t operand_id = to_unique_operand_id[operand];
1963     if (emit_operand_blocks[operand_id] != nullptr) {
1964       continue;
1965     }
1966 
1967     emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
1968         exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
1969     auto saved_insert_point = b_->GetInsertPoint();
1970     llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
1971     source_index_phis[operand_id] =
1972         b_->CreatePHI(source_index.GetType(), operand_usage_count[operand_id]);
1973     std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
1974     operand_multi_index[concat_dim] = b_->CreateNSWSub(
1975         operand_multi_index[concat_dim], source_index_phis[operand_id]);
1976 
1977     // Create the terminator of the block before calling operand generators,
1978     // because they require non-degenerate basic blocks.
1979     b_->SetInsertPoint(llvm::BranchInst::Create(
1980         exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
1981     llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
1982                                           source_index.GetType());
1983 
1984     TF_ASSIGN_OR_RETURN(llvm::Value * value,
1985                         operand_to_generator.at(operand)(operand_index));
1986     output->addIncoming(value, b_->GetInsertBlock());
1987     b_->SetInsertPoint(init_block, saved_insert_point);
1988   }
1989 
1990   // We use bisection to select the input operand.
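  // For example, with four operands whose slices start at offsets 0, 3, 5 and
  // 9 along concat_dim, the emitted tree first compares the index against 5,
  // then against 3 or 9, and branches to the matching per-operand block.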
1991   int64_t current_offset = 0;
1992 
1993   // Offset for every operand.
1994   std::vector<std::pair<int64_t, const HloInstruction*>> cases;
1995 
1996   cases.reserve(hlo->operand_count());
1997   for (const HloInstruction* operand : hlo->operands()) {
1998     cases.emplace_back(current_offset, operand);
1999     current_offset += operand->shape().dimensions(concat_dim);
2000   }
2001   CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim));
2002 
2003   std::function<llvm::BasicBlock*(
2004       absl::Span<const std::pair<int64_t, const HloInstruction*>> operands)>
2005       emit_tree =
2006           [&](absl::Span<const std::pair<int64_t, const HloInstruction*>>
2007                   operands) {
2008             llvm::IRBuilder<>::InsertPointGuard guard(*b_);
2009             size_t mid = operands.size() / 2;
2010             const std::pair<int64_t, const HloInstruction*>& pivot =
2011                 operands[mid];
2012             llvm::BasicBlock* block = llvm_ir::CreateBasicBlock(
2013                 exit_block,
2014                 absl::StrCat("concatenate.pivot.", pivot.first, "."), b_);
2015             b_->SetInsertPoint(block);
2016 
2017             // If there's only one element, we're done. The range is
2018             // contiguous, so we can just jump to the block for it.
2019             if (operands.size() == 1) {
2020               const std::pair<int64_t, const HloInstruction*>& operand =
2021                   operands.back();
2022               int64_t operand_id = to_unique_operand_id[operand.second];
2023 
2024               source_index_phis[operand_id]->addIncoming(
2025                   source_index.GetConstantWithIndexType(operand.first),
2026                   b_->GetInsertBlock());
2027               b_->CreateBr(emit_operand_blocks[operand_id]);
2028               return block;
2029             }
2030 
2031             // Take the middle element and recurse.
2032             llvm::Constant* pivot_const = llvm::ConstantInt::get(
2033                 source_index[concat_dim]->getType(), pivot.first);
2034             llvm::Value* comp =
2035                 b_->CreateICmpULT(source_index[concat_dim], pivot_const);
2036 
2037             llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid));
2038             llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid));
2039 
2040             b_->CreateCondBr(comp, left_block, right_block);
2041             return block;
2042           };
2043 
2044   Br(emit_tree(cases));
2045 
2046   b_->SetInsertPoint(exit_block, prior_insert_point);
2047   return output;
2048 }
2049 
2050 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
2051     const HloInstruction* hlo,
2052     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2053     const llvm_ir::IrArray::Index& index) {
2054   // Emit IR to read the dynamic start indices from hlo->operand(1 + i),
2055   // one scalar per dimension i.
2055   const HloInstruction* input_hlo = hlo->operand(0);
2056   const int64_t rank = input_hlo->shape().rank();
2057   // Use the same index type for all tensor accesses in the same kernel.
2058   llvm::Type* index_type = index.GetType();
2059   std::vector<llvm::Value*> slice_start_multi_index(rank);
2060   for (int64_t i = 0; i < rank; ++i) {
2061     auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
2062       return llvm::ConstantInt::get(index_type, c);
2063     };
2064     llvm_ir::IrArray::Index zero_index(index_type);
2065     TF_ASSIGN_OR_RETURN(
2066         llvm::Value * start_index_value,
2067         operand_to_generator.at(hlo->operand(1 + i))(zero_index));
2068 
2069     // Clamp the start index so that the sliced portion fits in the operand:
2070     // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
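    // For example, slicing 4 elements from a dimension of size 10 clamps the
    // start index to the range [0, 6].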
2071     start_index_value = SExtOrTrunc(start_index_value, index_type);
2072     int64_t largest_valid_start_index =
2073         input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
2074     CHECK_GE(largest_valid_start_index, 0);
2075 
2076     bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
2077     start_index_value = EmitIntegralMin(
2078         index_typed_const(largest_valid_start_index),
2079         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2080         is_signed);
2081 
2082     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2083     slice_start_multi_index[i] = start_index_value;
2084   }
2085 
2086   std::vector<llvm::Value*> input_multi_index(rank);
2087   for (int64_t i = 0; i < rank; ++i) {
2088     // Emit IR which computes:
2089     //   input_index = start_index + offset_index
2090     input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
2091   }
2092   llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
2093                                       index_type);
2094   return operand_to_generator.at(input_hlo)(input_index);
2095 }
2096 
2097 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
2098     const HloInstruction* hlo,
2099     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2100     const llvm_ir::IrArray::Index& index) {
2101   const Shape& operand_shape = hlo->operand(0)->shape();
2102   const Shape& indices_shape = hlo->operand(1)->shape();
2103   const Shape& output_shape = hlo->shape();
2104 
2105   const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
2106 
2107   const llvm_ir::ElementGenerator& operand_generator =
2108       operand_to_generator.at(hlo->operand(0));
2109   const llvm_ir::ElementGenerator& indices_generator =
2110       operand_to_generator.at(hlo->operand(1));
2111 
2112   llvm::Type* index_type = index.GetType();
2113   // This is the index into `operand` that holds the element we want to
2114   // generate.
2115   std::vector<llvm::Value*> operand_multi_index;
2116 
2117   // First copy the window indices into operand_multi_index. Also collect a
2118   // mapping from operand dimension to output window dimension. Elided window
2119   // dimensions map to -1.
2120   std::vector<int64_t> operand_to_output_dim(operand_shape.dimensions_size(),
2121                                              -1);
2122   for (int64_t i = 0, e = operand_shape.dimensions_size(),
2123                operand_index_dim = 0;
2124        i < e; i++) {
2125     if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
2126       operand_multi_index.push_back(index.GetConstantWithIndexType(0));
2127     } else {
2128       int64_t output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
2129       operand_to_output_dim[i] = output_window_dim;
2130       operand_multi_index.push_back(index[output_window_dim]);
2131     }
2132   }
2133 
2134   // This is the index of the index vector in the start_indices tensor.
2135   std::vector<llvm::Value*> gather_index_index_components;
2136   {
2137     for (int64_t i = 0, e = output_shape.dimensions_size(); i < e; i++) {
2138       if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
2139         gather_index_index_components.push_back(index[i]);
2140       }
2141     }
2142 
2143     if (gather_index_index_components.size() !=
2144         indices_shape.dimensions_size()) {
2145       gather_index_index_components.insert(
2146           gather_index_index_components.begin() +
2147               dim_numbers.index_vector_dim(),
2148           nullptr);
2149     }
2150   }
2151 
2152   auto add_to_operand_index = [&](llvm::Value* index_component, int64_t dim) {
2153     auto index_component_type = index_component->getType();
2154     auto extended_type = index_component_type->getScalarSizeInBits() >=
2155                                  index_type->getScalarSizeInBits()
2156                              ? index_component_type
2157                              : index_type;
2158     // Possibly sign-extend the index component first so that the clamping
2159     // arithmetic below stays in bounds.
2160     auto maybe_extended_index =
2161         index_component_type != extended_type
2162             ? b_->CreateSExt(index_component, extended_type)
2163             : index_component;
2164     int64_t operand_dim = dim_numbers.start_index_map(dim);
2165     int64_t output_dim = operand_to_output_dim[operand_dim];
2166     // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
2167     // This means we set the iteration index to 0, so for the purpose of the
2168     // following calculations we can consider the output dimension size to be 1.
2169     int64_t output_dim_size =
2170         output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
2171     int64_t largest_valid_start_index =
2172         operand_shape.dimensions(operand_dim) - output_dim_size;
2173     CHECK_GE(largest_valid_start_index, 0);
2174 
2175     // Clamp the gather index so that the gather region fits in the operand.
2176     // clamped_index =
2177     //     clamp(gather_dim_component_extended, 0, largest_valid_start_index);
2178     bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
2179     auto clamped_index = EmitIntegralMin(
2180         llvm::ConstantInt::get(extended_type, largest_valid_start_index),
2181         EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
2182                         maybe_extended_index, is_signed),
2183         is_signed);
2184     // Truncate back to the (possibly narrower) index type at the end.
2185     auto maybe_truncated_clamped_index = extended_type != index_type
2186                                              ? Trunc(clamped_index, index_type)
2187                                              : clamped_index;
2188 
2189     operand_multi_index[operand_dim] =
2190         Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
2191   };
2192 
2193   if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
2194     IrArray::Index gather_index_index(gather_index_index_components,
2195                                       indices_shape, index_type);
2196     TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2197                         indices_generator(gather_index_index));
2198     add_to_operand_index(gather_dim_component, 0);
2199   } else {
2200     int64_t index_vector_size =
2201         indices_shape.dimensions(dim_numbers.index_vector_dim());
2202     for (int64_t i = 0; i < index_vector_size; i++) {
2203       gather_index_index_components[dim_numbers.index_vector_dim()] =
2204           index.GetConstantWithIndexType(i);
2205       IrArray::Index gather_index_index(gather_index_index_components,
2206                                         indices_shape, index_type);
2207       TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2208                           indices_generator(gather_index_index));
2209       add_to_operand_index(gather_dim_component, i);
2210     }
2211   }
2212   IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
2213   return operand_generator(operand_index);
2214 }
2215 
2216 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
2217     const HloInstruction* hlo,
2218     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2219     const llvm_ir::IrArray::Index& index) {
2220   const HloInstruction* input_hlo = hlo->operand(0);
2221   const HloInstruction* update_hlo = hlo->operand(1);
2222   const HloInstruction* start_hlo = hlo->operand(2);
2223   // Calculate slice start/end indices.
2224   const int64_t rank = input_hlo->shape().rank();
2225   std::vector<llvm::Value*> slice_start_multi_index(rank);
2226   std::vector<llvm::Value*> slice_limit_multi_index(rank);
2227   // 'slice_intersection' ANDs the per-dimension conditions under which the
2228   // element of 'input' at 'index' is replaced by the corresponding 'update'.
2229   llvm::Value* slice_intersection = b_->getTrue();
2230 
2231   for (int64_t i = 0; i < rank; ++i) {
2232     llvm::Type* index_type = index[0]->getType();
2233     auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
2234       return llvm::ConstantInt::get(index_type, c);
2235     };
2236 
2237     llvm_ir::IrArray::Index zero_index(index_type);
2238     TF_ASSIGN_OR_RETURN(
2239         llvm::Value * start_index_value,
2240         operand_to_generator.at(hlo->operand(2 + i))(zero_index));
2241 
2242     // Clamp the start index so that the update region fits in the operand.
2243     // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
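    // Illustrative example: updating a size-3 slice of a size-8 dimension
    // clamps start_index to the range [0, 5].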
2244     start_index_value = SExtOrTrunc(start_index_value, index_type);
2245     llvm::Value* update_dim_size =
2246         index_typed_const(update_hlo->shape().dimensions(i));
2247     int64_t largest_valid_start_index =
2248         input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
2249     CHECK_GE(largest_valid_start_index, 0);
2250 
2251     bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
2252     start_index_value = EmitIntegralMin(
2253         index_typed_const(largest_valid_start_index),
2254         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2255         is_signed);
2256 
2257     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2258     slice_start_multi_index[i] = start_index_value;
2259     slice_limit_multi_index[i] =
2260         Add(slice_start_multi_index[i], update_dim_size);
2261 
2262     slice_intersection =
2263         And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
2264             "slice_intersection");
2265     slice_intersection =
2266         And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
2267             "slice_intersection");
2268   }
2269 
2270   // Emit:
2271   // if (slice_intersection) -> return data from 'update'.
2272   // else                    -> return data from 'input'.
2273   llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2274       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2275       "ret_value_addr", b_);
2276   llvm_ir::LlvmIfData if_data =
2277       llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
2278 
2279   // Handle true BB (return data from 'update')
2280   SetToFirstInsertPoint(if_data.true_block, b_);
2281   // Compute update index for intersection case.
2282   std::vector<llvm::Value*> update_multi_index(rank);
2283   for (int64_t i = 0; i < rank; ++i) {
2284     update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
2285   }
2286   llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
2287                                        index.GetType());
2288   TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
2289                       operand_to_generator.at(update_hlo)(update_index));
2290   Store(true_value, ret_value_addr);
2291 
2292   // Handle false BB (return data from 'input')
2293   SetToFirstInsertPoint(if_data.false_block, b_);
2294   TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
2295                       operand_to_generator.at(input_hlo)(index));
2296   Store(false_value, ret_value_addr);
2297 
2298   SetToFirstInsertPoint(if_data.after_block, b_);
2299   return Load(ret_value_addr->getAllocatedType(), ret_value_addr);
2300 }
2301 
2302 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
2303     const HloInstruction* hlo,
2304     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2305     const llvm_ir::IrArray::Index& padded_index) {
2306   std::vector<llvm::Value*> multi_index = padded_index.multidim();
2307   llvm::Value* in_bounds = b_->getTrue();
2308   for (size_t i = 0; i < multi_index.size(); ++i) {
2309     auto index_typed_const = [=](int64_t n) {
2310       return padded_index.GetConstantWithIndexType(n);
2311     };
2312     const auto& pad_dim = hlo->padding_config().dimensions(i);
2313     multi_index[i] =
2314         Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
2315     in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
2316                     "in_bounds");
2317     in_bounds =
2318         And(in_bounds,
2319             ICmpEQ(index_typed_const(0),
2320                    URem(multi_index[i],
2321                         index_typed_const(pad_dim.interior_padding() + 1))),
2322             "in_bounds");
2323     multi_index[i] =
2324         SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
2325     in_bounds =
2326         And(in_bounds,
2327             ICmpSLT(multi_index[i],
2328                     index_typed_const(hlo->operand(0)->shape().dimensions(i))),
2329             "in_bounds");
2330   }
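  // Illustrative example: with edge_padding_low = 1 and interior_padding = 1,
  // padded index 3 maps to operand index (3 - 1) / 2 = 1 and is in bounds,
  // while padded index 2 lands on interior padding ((2 - 1) % 2 != 0) and
  // takes the padding value instead.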
2331 
2332   // if (in_bounds) {
2333   //   ret_value = operand0[index];  // source
2334   // } else {
2335   //   ret_value = *operand1;        // padding
2336   // }
2337   llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2338       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2339       "pad_result_addr", b_);
2340   llvm_ir::LlvmIfData if_data =
2341       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2342   SetToFirstInsertPoint(if_data.true_block, b_);
2343   llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
2344                                 padded_index.GetType());
2345   TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2346                       operand_to_generator.at(hlo->operand(0))(index));
2347   Store(operand_value, ret_value_addr);
2348 
2349   SetToFirstInsertPoint(if_data.false_block, b_);
2350   TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
2351                       operand_to_generator.at(hlo->operand(1))(
2352                           IrArray::Index(index.GetType())));
2353   Store(padding_value, ret_value_addr);
2354 
2355   SetToFirstInsertPoint(if_data.after_block, b_);
2356   // Don't create phi(operand_value, padding_value) here, because invoking
2357   // operand_to_generator may create new basic blocks, making the parent
2358   // of operand_value or padding_value no longer a predecessor of
2359   // if_data.after_block.
2360   return Load(ret_value_addr->getAllocatedType(), ret_value_addr);
2361 }
2362 
2363 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
2364     const HloInstruction* hlo,
2365     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2366     const llvm_ir::IrArray::Index& dot_result_index) {
2367   auto lhs_generator = operand_to_generator.at(hlo->operand(0));
2368   auto rhs_generator = operand_to_generator.at(hlo->operand(1));
2369 
2370   const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
2371   int64_t lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
2372   int64_t rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
2373 
2374   int64_t contracted_dim_size =
2375       hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
2376   int64_t lhs_dims = hlo->operand(0)->shape().dimensions_size();
2377   int64_t rhs_dims = hlo->operand(1)->shape().dimensions_size();
2378 
2379   llvm::Type* index_type = dot_result_index.GetType();
2380   auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
2381     return llvm::ConstantInt::get(index_type, c);
2382   };
2383 
2384   std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
2385       IrName(hlo, "inner"), index_typed_const(0),
2386       index_typed_const(contracted_dim_size), index_typed_const(1), b_);
2387 
2388   SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
2389   PrimitiveType primitive_type = hlo->shape().element_type();
2390   llvm::Type* primitive_type_llvm =
2391       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2392   llvm::AllocaInst* accumulator_alloca =
2393       llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
2394   Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
2395 
2396   SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
2397 
2398   // This is the inner reduction loop for a dot operation that produces
2399   // one element in the output.  If the operands to the dot operation have
2400   // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
2401   // Given an output index [a,b,c,d,e] in the result, we compute:
2402   //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
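  // Illustrative example: for lhs shape [2,3,4] contracting its last dimension
  // with rhs shape [4,5], the result shape is [2,3,5]; result element [a,b,e]
  // is sum(lhs[a,b,t] * rhs[t,e] for t in [0, 4)).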
2403 
2404   std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
2405   for (int64_t i = 0; i < lhs_dims - 1; i++) {
2406     lhs_multi_index.push_back(dot_result_index[i]);
2407   }
2408   lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
2409                          inner_loop->GetIndVarValue());
2410   IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
2411                            index_type);
2412 
2413   int64_t num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
2414   for (int64_t i = 0; i < num_batch_dims; i++) {
2415     rhs_multi_index.push_back(
2416         dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
2417   }
2418   for (int64_t i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
2419     rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
2420   }
2421   rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
2422                          inner_loop->GetIndVarValue());
2423   IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
2424                            index_type);
2425 
2426   llvm::Value* current_accumulator =
2427       Load(accumulator_alloca->getAllocatedType(), accumulator_alloca);
2428   TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
2429   TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
2430   llvm::Value* next_accumulator =
2431       EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
2432   Store(next_accumulator, accumulator_alloca);
2433 
2434   SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
2435   return Load(accumulator_alloca->getAllocatedType(), accumulator_alloca);
2436 }
2437 
2438 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2439     const HloInstruction* hlo,
2440     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2441   switch (hlo->opcode()) {
2442     case HloOpcode::kAbs:
2443     case HloOpcode::kRoundNearestAfz:
2444     case HloOpcode::kRoundNearestEven:
2445     case HloOpcode::kCeil:
2446     case HloOpcode::kClz:
2447     case HloOpcode::kConvert:
2448     case HloOpcode::kBitcastConvert:
2449     case HloOpcode::kCos:
2450     case HloOpcode::kExp:
2451     case HloOpcode::kExpm1:
2452     case HloOpcode::kFloor:
2453     case HloOpcode::kImag:
2454     case HloOpcode::kIsFinite:
2455     case HloOpcode::kLog:
2456     case HloOpcode::kLog1p:
2457     case HloOpcode::kNegate:
2458     case HloOpcode::kNot:
2459     case HloOpcode::kPopulationCount:
2460     case HloOpcode::kReal:
2461     case HloOpcode::kRsqrt:
2462     case HloOpcode::kSign:
2463     case HloOpcode::kSin:
2464     case HloOpcode::kSqrt:
2465     case HloOpcode::kCbrt:
2466     case HloOpcode::kTanh:
2467       return [this, hlo, &operand_to_generator](
2468                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2469         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2470                             operand_to_generator.at(hlo->operand(0))(index));
2471         return EmitUnaryOp(hlo, operand_value);
2472       };
2473     case HloOpcode::kAdd:
2474     case HloOpcode::kAnd:
2475     case HloOpcode::kAtan2:
2476     case HloOpcode::kCompare:
2477     case HloOpcode::kComplex:
2478     case HloOpcode::kDivide:
2479     case HloOpcode::kMaximum:
2480     case HloOpcode::kMinimum:
2481     case HloOpcode::kMultiply:
2482     case HloOpcode::kOr:
2483     case HloOpcode::kXor:
2484     case HloOpcode::kPower:
2485     case HloOpcode::kRemainder:
2486     case HloOpcode::kShiftLeft:
2487     case HloOpcode::kShiftRightArithmetic:
2488     case HloOpcode::kShiftRightLogical:
2489     case HloOpcode::kSubtract:
2490       return [this, hlo, &operand_to_generator](
2491                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2492         const HloInstruction* lhs = hlo->operand(0);
2493         const HloInstruction* rhs = hlo->operand(1);
2494         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2495                             operand_to_generator.at(lhs)(index));
2496         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2497                             operand_to_generator.at(rhs)(index));
2498         return EmitBinaryOp(hlo, lhs_value, rhs_value);
2499       };
2500     case HloOpcode::kSelect:
2501       return [this, hlo, &operand_to_generator](
2502                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2503         return EmitElementalSelect(hlo, operand_to_generator, index);
2504       };
2505     case HloOpcode::kClamp:
2506       return [this, hlo, &operand_to_generator](
2507                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2508         return EmitElementalClamp(hlo, operand_to_generator, index);
2509       };
2510     case HloOpcode::kReducePrecision:
2511       return [this, hlo, &operand_to_generator](
2512                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2513         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2514                             operand_to_generator.at(hlo->operand(0))(index));
2515         return EmitReducePrecision(hlo, operand_value);
2516       };
2517     case HloOpcode::kConcatenate:
2518       return [this, hlo, &operand_to_generator](
2519                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2520         return EmitElementalConcatenate(hlo, operand_to_generator,
2521                                         target_index);
2522       };
2523     case HloOpcode::kReverse:
2524       return [this, hlo, &operand_to_generator](
2525                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2526         const HloInstruction* operand = hlo->operand(0);
2527         std::vector<llvm::Value*> source_multi_index = target_index.multidim();
2528         for (int64_t dim : hlo->dimensions()) {
2529           source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
2530                                             hlo->shape().dimensions(dim) - 1),
2531                                         target_index[dim]);
2532         }
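        // Illustrative example: reversing a dimension of size 5 maps target
        // index 1 to source index (5 - 1) - 1 = 3.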
2533         llvm_ir::IrArray::Index source_index(
2534             source_multi_index, operand->shape(), target_index.GetType());
2535         return operand_to_generator.at(operand)(source_index);
2536       };
2537     case HloOpcode::kBroadcast:
2538       return [this, hlo, &operand_to_generator](
2539                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2540         const HloInstruction* operand = hlo->operand(0);
2541         // The `dimensions` member of the broadcast instruction maps from
2542         // input dimensions to output dimensions.
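        // Illustrative example: broadcasting an operand of shape [3] into
        // shape [2, 3] with dimensions = {1} maps output index [i, j] to
        // operand index [j].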
2543         return operand_to_generator.at(operand)(
2544             target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
2545                                                 hlo->dimensions(), b_));
2546       };
2547     case HloOpcode::kIota:
2548       return [this, hlo](
2549                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2550         auto* iota = Cast<HloIotaInstruction>(hlo);
2551         PrimitiveType element_type = iota->shape().element_type();
2552         IrArray::Index elem_index =
2553             iota->shape().rank() > 1
2554                 ? target_index.SourceIndexOfBroadcast(
2555                       iota->shape(),
2556                       ShapeUtil::MakeShapeWithDescendingLayout(
2557                           element_type,
2558                           {iota->shape().dimensions(iota->iota_dimension())}),
2559                       {iota->iota_dimension()}, b_)
2560                 : target_index;
2561         llvm::Value* elem_index_linear = elem_index.linear();
2562         if (elem_index_linear == nullptr) {
2563           std::vector<int64_t> iota_bound = {
2564               iota->shape().dimensions(iota->iota_dimension())};
2565           elem_index_linear = elem_index.Linearize(iota_bound, b_);
2566         }
2567         Shape component_shape =
2568             ShapeUtil::ElementIsComplex(iota->shape())
2569                 ? ShapeUtil::ComplexComponentShape(iota->shape())
2570                 : iota->shape();
2571         PrimitiveType component_element_type = component_shape.element_type();
2572         llvm::Value* iota_result;
2573         if (primitive_util::IsIntegralType(component_element_type)) {
2574           iota_result = b_->CreateIntCast(
2575               elem_index_linear,
2576               llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
2577               /*isSigned=*/false);
2578         } else {
2579           TF_RET_CHECK(
2580               primitive_util::IsFloatingPointType(component_element_type))
2581               << component_element_type;
2582           llvm::Type* float_ir_type;
2583           if (component_element_type == BF16) {
2584             float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
2585           } else {
2586             float_ir_type =
2587                 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
2588           }
2589           llvm::Value* float_val =
2590               b_->CreateUIToFP(elem_index_linear, float_ir_type);
2591           if (component_element_type == BF16) {
2592             TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
2593           } else {
2594             iota_result = float_val;
2595           }
2596         }
2597         if (ShapeUtil::ElementIsComplex(iota->shape())) {
2598           return EmitComposeComplex(iota, iota_result, nullptr);
2599         } else {
2600           return iota_result;
2601         }
2602       };
2603     case HloOpcode::kSlice:
2604       return [this, hlo, &operand_to_generator](
2605                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2606         IrArray::Index sliced_index = index.SourceIndexOfSlice(
2607             /*operand_shape=*/hlo->operand(0)->shape(),
2608             /*starts=*/hlo->slice_starts(),
2609             /*strides=*/hlo->slice_strides(), /*builder=*/b_);
2610         return operand_to_generator.at(hlo->operand(0))(sliced_index);
2611       };
2612     case HloOpcode::kDynamicSlice:
2613       return [this, hlo, &operand_to_generator](
2614                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2615         return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
2616       };
2617 
2618     case HloOpcode::kGather:
2619       return [this, hlo, &operand_to_generator](
2620                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2621         return EmitElementalGather(hlo, operand_to_generator, index);
2622       };
2623     case HloOpcode::kDynamicUpdateSlice:
2624       return [this, hlo, &operand_to_generator](
2625                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2626         return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
2627                                                index);
2628       };
2629     case HloOpcode::kBitcast:
2630       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2631                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2632       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2633         const HloInstruction* operand = hlo->operand(0);
2634         return operand_to_generator.at(operand)(
2635             GetSourceIndexOfBitcast(index, hlo));
2636       };
2637     case HloOpcode::kReshape:
2638       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2639                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2640       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2641         const HloInstruction* operand = hlo->operand(0);
2642         return operand_to_generator.at(operand)(
2643             index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
2644       };
2645     case HloOpcode::kCopy:
2646       return [hlo, &operand_to_generator](
2647                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2648         IrArray::Index source_index(target_index.multidim(),
2649                                     hlo->operand(0)->shape(),
2650                                     target_index.GetType());
2651         TF_ASSIGN_OR_RETURN(
2652             llvm::Value * operand_value,
2653             operand_to_generator.at(hlo->operand(0))(source_index));
2654         return operand_value;
2655       };
2656     case HloOpcode::kTranspose:
2657       return [this, hlo,
2658               &operand_to_generator](const IrArray::Index& target_index) {
2659         return operand_to_generator.at(hlo->operand(0))(
2660             target_index.SourceIndexOfTranspose(
2661                 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
2662       };
2663     case HloOpcode::kPad:
2664       return [this, hlo, &operand_to_generator](
2665                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
2666         return EmitElementalPad(hlo, operand_to_generator, padded_index);
2667       };
2668 
2669     case HloOpcode::kDot:
2670       return [this, hlo,
2671               &operand_to_generator](const IrArray::Index& dot_result_index)
2672                  -> StatusOr<llvm::Value*> {
2673         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
2674       };
2675     case HloOpcode::kMap:
2676       return [this, hlo, &operand_to_generator](
2677                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2678         std::vector<llvm::Value*> operands;
2679         for (int i = 0; i < hlo->operand_count(); i++) {
2680           TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2681                               operand_to_generator.at(hlo->operand(i))(index));
2682           operands.push_back(operand_value);
2683         }
2684         return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
2685       };
2686     case HloOpcode::kReduceWindow:
2687       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2688         auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
2689         std::vector<llvm_ir::ElementGenerator> input_generators;
2690         for (const HloInstruction* instr : reduce_window_instr->inputs()) {
2691           input_generators.push_back(operand_to_generator.at(instr));
2692         }
2693 
2694         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2695         for (const HloInstruction* instr : reduce_window_instr->init_values()) {
2696           initial_value_generators.push_back(operand_to_generator.at(instr));
2697         }
2698         return EmitElementalReduceWindow(
2699             Cast<HloReduceWindowInstruction>(hlo), std::move(input_generators),
2700             std::move(initial_value_generators), index);
2701       };
2702     case HloOpcode::kReduce:
2703       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2704         auto reduce_instr = Cast<HloReduceInstruction>(hlo);
2705         std::vector<llvm_ir::ElementGenerator> input_generators;
2706         for (const HloInstruction* instr : reduce_instr->inputs()) {
2707           input_generators.push_back(operand_to_generator.at(instr));
2708         }
2709 
2710         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2711         for (const HloInstruction* instr : reduce_instr->init_values()) {
2712           initial_value_generators.push_back(operand_to_generator.at(instr));
2713         }
2714         return EmitElementalReduce(reduce_instr, std::move(input_generators),
2715                                    std::move(initial_value_generators), index);
2716       };
2717     case HloOpcode::kConvolution:
2718       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2719         return EmitConvolution(hlo, operand_to_generator, index);
2720       };
2721     default:
2722       return [hlo](const IrArray::Index& index) {
2723         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
2724                              HloOpcodeString(hlo->opcode()));
2725       };
2726   }
2727 }
2728 
2729 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
2730   return ExtractValue(value, {0});
2731 }
2732 
2733 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
2734   return ExtractValue(value, {1});
2735 }
2736 
2737 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
2738                                                     llvm::Value* real,
2739                                                     llvm::Value* imag) {
2740   auto cplx_type =
2741       llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
2742   auto complex =
2743       InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
2744   if (imag != nullptr) {
2745     complex = InsertValue(complex, imag, {1});
2746   }
2747   return complex;
2748 }
2749 
2750 llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
2751                                             llvm::Value* accumulator,
2752                                             xla::PrimitiveType primitive_type) {
2753   if (primitive_util::IsComplexType(primitive_type)) {
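    // Complex multiply-add: accumulator + lhs * rhs, expanded via
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i.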
2754     llvm::Value* product_real =
2755         FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
2756              FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
2757     llvm::Value* product_imag =
2758         FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
2759              FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
2760     llvm::Value* next_accumulator = InsertValue(
2761         accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
2762     return InsertValue(next_accumulator,
2763                        FAdd(EmitExtractImag(accumulator), product_imag), {1});
2764   } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2765     return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
2766   } else if (primitive_type == PRED) {
2767     return Or(accumulator, And(lhs, rhs));
2768   }
2769   return Add(accumulator, Mul(lhs, rhs));
2770 }
2771 
2772 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
2773     const HloMapInstruction* map_instr,
2774     absl::Span<llvm::Value* const> elemental_operands) {
2775   TF_ASSIGN_OR_RETURN(
2776       std::vector<llvm::Value*> values,
2777       EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
2778                           llvm_ir::IrName(map_instr), /*is_reducer=*/false));
2779   CHECK_EQ(values.size(), 1);
2780   return values[0];
2781 }
2782 
2783 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
2784     const HloReduceWindowInstruction* reduce_window,
2785     std::vector<llvm_ir::ElementGenerator> input_generators,
2786     std::vector<llvm_ir::ElementGenerator> initial_value_generators,
2787     const llvm_ir::IrArray::Index& index) {
2788   // Pseudocode:
2789   // for each output index O
2790   //   value = init_value
2791   //   for each window index W
2792   //     for each dimension i from 0 to rank - 1
2793   //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
2794   //     if I in bounds of input
2795   //       value = function(value, input[I])
2796   //   output[O] = value
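  // Illustrative example: a 1-D reduce-window with window size 3, stride 2,
  // and padding 1 on each side over an input of length 5 produces 3 outputs;
  // output[0] reduces input indices {-1, 0, 1}, of which only {0, 1} are in
  // bounds.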
2797   int64_t input_count = reduce_window->input_count();
2798   std::vector<PrimitiveType> operand_element_types;
2799   std::vector<llvm::Type*> accum_types;
2800   std::vector<llvm::Value*> accum_ptrs;
2801   for (int64_t operand_index = 0; operand_index < input_count;
2802        ++operand_index) {
2803     auto operand = reduce_window->inputs()[operand_index];
2804     PrimitiveType operand_element_type = operand->shape().element_type();
2805     operand_element_types.push_back(operand_element_type);
2806     llvm::Type* llvm_type =
2807         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
2808     accum_types.push_back(llvm_type);
2809     llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
2810         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
2811         "reduce_window_accum_ptr", b_);
2812     accum_ptrs.push_back(accum_ptr);
2813     {
2814       auto initial_value_generator = initial_value_generators[operand_index];
2815       TF_ASSIGN_OR_RETURN(
2816           llvm::Value* const init_value,
2817           initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
2818       Store(init_value, accum_ptr);
2819     }
2820   }
2821 
2822   llvm::Type* index_type = index.GetType();
2823   auto index_typed_const = [&](uint64_t c) -> llvm::Constant* {
2824     return index.GetConstantWithIndexType(c);
2825   };
2826 
2827   const Window& window = reduce_window->window();
2828   llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
2829   std::vector<int64_t> window_size;
2830   const auto& dimensions = window.dimensions();
2831   window_size.reserve(dimensions.size());
2832   for (const auto& dim : dimensions) {
2833     window_size.push_back(dim.size());
2834   }
2835   const IrArray::Index window_index = loops.AddLoopsForShape(
2836       ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
2837   CHECK_EQ(window_index.size(), index.size());
2838 
2839   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
2840 
2841   std::vector<llvm::Value*> input_multi_index(index.size());
2842   llvm::Value* in_bounds = b_->getInt1(true);
2843   for (size_t i = 0; i < index.size(); ++i) {
2844     llvm::Value* stridden_index =
2845         NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
2846     input_multi_index[i] = NSWSub(
2847         NSWAdd(
2848             stridden_index,
2849             NSWMul(window_index[i],
2850                    index_typed_const(window.dimensions(i).window_dilation()))),
2851         index_typed_const(window.dimensions(i).padding_low()));
2852 
2853     // We need to verify that we are not in a hole created by base dilation.
2854     llvm::Value* dilation_condition =
2855         ICmpEQ(SRem(input_multi_index[i],
2856                     index_typed_const(window.dimensions(i).base_dilation())),
2857                index_typed_const(0));
2858     in_bounds = And(in_bounds, dilation_condition);
2859 
2860     // Apply base dilation to the index.
2861     input_multi_index[i] =
2862         SDiv(input_multi_index[i],
2863              index_typed_const(window.dimensions(i).base_dilation()));
2864 
2865     // We must check whether 0 <= input_multi_index[i] < bound, as
2866     // otherwise we are in the pad and so can skip the computation. This
2867     // comparison is equivalent to the unsigned comparison
2868     // input_multi_index[i] < bound, as a negative value wraps to a large
2869     // positive value.
2870     in_bounds =
2871         And(in_bounds,
2872             ICmpULT(input_multi_index[i],
2873                     index_typed_const(
2874                         reduce_window->inputs()[0]->shape().dimensions(i))));
2875   }
2876 
2877   llvm_ir::LlvmIfData if_data =
2878       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2879   SetToFirstInsertPoint(if_data.true_block, b_);
2880 
2881   // We are not in the pad, so do the computation.
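  // The reducer is called with all current accumulator values followed by the
  // freshly generated input elements, so the first half of input_values holds
  // accumulators and the second half holds inputs.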
2882   std::vector<llvm::Value*> input_values(reduce_window->operand_count());
2883   IrArray::Index input_index(input_multi_index,
2884                              reduce_window->inputs()[0]->shape(), index_type);
2885   for (int64_t operand_idx = 0; operand_idx < input_count; ++operand_idx) {
2886     TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
2887                         input_generators[operand_idx](input_index));
2888     input_values[input_count + operand_idx] = input_value;
2889     input_values[operand_idx] =
2890         Load(llvm::cast<llvm::AllocaInst>(accum_ptrs[operand_idx])
2891                  ->getAllocatedType(),
2892              accum_ptrs[operand_idx]);
2893   }
2894   TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
2895                       EmitThreadLocalCall(*reduce_window->to_apply(),
2896                                           input_values, "reducer_function",
2897                                           /*is_reducer=*/true));
2898 
2899   for (int64_t operand_idx = 0; operand_idx < accum_values.size();
2900        ++operand_idx) {
2901     Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
2902   }
2903 
2904   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
2905   return EmitAccumResult(accum_ptrs, accum_types,
2906                          reduce_window->shape().IsTuple());
2907 }
2908 
2909 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
2910     const HloReduceInstruction* reduce,
2911     std::vector<llvm_ir::ElementGenerator> input_generators,
2912     std::vector<llvm_ir::ElementGenerator> initial_value_generators,
2913     const llvm_ir::IrArray::Index& index) {
2914   const Shape& out_shape = reduce->shape();
2915   bool is_variadic = !out_shape.IsArray();
2916   int accumulators_count = 1;
2917   if (is_variadic) {
2918     CHECK(out_shape.IsTuple());
2919     accumulators_count = out_shape.tuple_shapes_size();
2920   }
2921 
2922   absl::Span<const int64_t> reduced_dimensions(reduce->dimensions());
2923 
2924   std::vector<llvm::Value*> accumulator_addrs;
2925   std::vector<llvm::Type*> accumulator_types;
2926   llvm::Type* index_type = index.GetType();
2927   for (int i = 0; i < accumulators_count; i++) {
2928     const Shape& element_shape =
2929         is_variadic ? out_shape.tuple_shapes(i) : out_shape;
2930     PrimitiveType accumulator_type = element_shape.element_type();
2931     llvm::Type* accumulator_llvm_type =
2932         llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
2933     accumulator_types.push_back(accumulator_llvm_type);
2934 
2935     // Initialize an accumulator with init_value.
2936     llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2937         accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
2938     TF_ASSIGN_OR_RETURN(
2939         llvm::Value* const init_value,
2940         initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
2941     Store(init_value, accumulator_addr);
2942     accumulator_addrs.push_back(accumulator_addr);
2943   }
2944 
2945   // The enclosing loops go over all the target elements. Now we have to compute
2946   // the actual target element. For this, we build a new loop nest to iterate
2947   // over all the reduction dimensions in the argument.
2948   // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
2949   // are placed for each dimension in dimensions, and all the rest are nullptrs.
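  // Illustrative example: reducing an argument of shape [4, 5, 6] over
  // dimension 1 yields input_multi_index = {nullptr, r, nullptr}; the
  // nullptrs are filled in below from 'index', giving {o0, r, o1}.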
2950   llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
2951   const HloInstruction* arg = reduce->operand(0);
2952   std::vector<llvm::Value*> input_multi_index =
2953       loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
2954                                          "reduction_dim");
2955 
2956   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
2957 
2958   // Build a full index for the input argument, using input_multi_index as the
2959   // base. In input_multi_index only the reduction dimensions are filled in. We
2960   // fill in the rest of the dimensions with induction Value*s taken from
2961   // 'index' which iterates over the target array.  See the high-level
2962   // description in the XLA documentation for details.
2963   auto it = index.begin();
2964 
2965   for (auto& i : input_multi_index) {
2966     if (i == nullptr) {
2967       i = *it++;
2968     }
2969   }
2970   CHECK(index.end() == it);
2971   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
2972                                       index_type);
2973 
2974   std::vector<llvm::Value*> reduction_operands;
2975   for (llvm::Value* accum : accumulator_addrs) {
2976     llvm::Value* accum_value =
2977         Load(llvm::cast<llvm::AllocaInst>(accum)->getAllocatedType(), accum);
2978     reduction_operands.push_back(accum_value);
2979   }
2980 
2981   for (int i = 0; i < accumulators_count; i++) {
2982     TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
2983                         input_generators[i](input_index));
2984     reduction_operands.push_back(input_element);
2985   }
2986 
2987   TF_ASSIGN_OR_RETURN(
2988       std::vector<llvm::Value*> results,
2989       EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
2990                           "reduce_function", /*is_reducer=*/true));
2991 
2992   CHECK(results.size() == accumulators_count);
2993   for (int i = 0; i < accumulators_count; i++) {
2994     Store(results[i], accumulator_addrs[i]);
2995   }
2996   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
2997   return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
2998 }
2999 
3000 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
3001     absl::Span<llvm::Value* const> accumulator_addrs,
3002     llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
3003   TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
3004   if (is_variadic) {
3005     // Emit a structure, as that is what the LoopEmitter expects.
3006     llvm::Value* returned_structure = llvm::UndefValue::get(
3007         llvm::StructType::get(b()->getContext(), accumulator_types));
3008     for (int64_t i = 0; i < accumulator_addrs.size(); i++) {
3009       llvm::Value* accumulator_value =
3010           Load(accumulator_types[i], accumulator_addrs[i]);
3011       returned_structure =
3012           b()->CreateInsertValue(returned_structure, accumulator_value, i);
3013     }
3014     return returned_structure;
3015   } else {
3016     CHECK_EQ(accumulator_addrs.size(), 1);
3017     return Load(accumulator_types[0], accumulator_addrs[0]);
3018   }
3019 }
3020 
3021 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
3022     const HloInstruction* convolution,
3023     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
3024     const llvm_ir::IrArray::Index& index) {
3025   TF_RET_CHECK(convolution->batch_group_count() == 1);
3026   const HloInstruction* lhs = convolution->operand(0);
3027   const auto& input_generator = operand_to_generator.at(lhs);
3028   const HloInstruction* rhs = convolution->operand(1);
3029   const auto& kernel_generator = operand_to_generator.at(rhs);
3030   const Window& window = convolution->window();
3031 
3032   const ConvolutionDimensionNumbers& dnums =
3033       convolution->convolution_dimension_numbers();
3034   int num_spatial_dims = dnums.output_spatial_dimensions_size();
3035   std::vector<llvm::Value*> output_spatial(num_spatial_dims);
3036   for (int i = 0; i < num_spatial_dims; ++i) {
3037     output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
3038   }
3039   llvm::Value* output_feature = index[dnums.output_feature_dimension()];
3040   llvm::Value* batch = index[dnums.output_batch_dimension()];
3041 
3042   // We will accumulate the products into this sum to calculate the output entry
3043   // at the given index.
3044   PrimitiveType lhs_element_type = lhs->shape().element_type();
3045   llvm::Type* lhs_llvm_type =
3046       llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
3047   // Upcast the accumulator to F32 from F16 for increased precision.
3048   llvm::Type* accumulator_type =
3049       lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
3050   llvm::AllocaInst* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
3051       accumulator_type, "convolution_sum_address", b_);
3052   llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
3053   Store(constant_zero, sum_address);
3054 
3055   llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
3056   std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
3057   for (int i = 0; i < num_spatial_dims; ++i) {
3058     kernel_spatial[i] =
3059         loops
3060             .AddLoop(
3061                 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
3062                 absl::StrCat("k", i))
3063             ->GetIndVarValue();
3064   }
3065   const int64_t input_group_size =
3066       rhs->shape().dimensions(dnums.kernel_input_feature_dimension());
3067   const int64_t feature_group_count = convolution->feature_group_count();
3068   const int64_t output_group_size =
3069       rhs->shape().dimensions(dnums.kernel_output_feature_dimension()) /
3070       feature_group_count;
3071   llvm::Value* input_feature =
3072       loops.AddLoop(0, input_group_size, "iz")->GetIndVarValue();
3073 
3074   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
3075 
3076   llvm::Value* group_id = SDiv(output_feature, b_->getInt64(output_group_size));
3077   llvm::Value* lhs_input_feature =
3078       NSWAdd(input_feature, NSWMul(group_id, b_->getInt64(input_group_size)));
3079 
3080   // Calculate the spatial index in the input array, taking striding, dilation
3081   // and padding into account. An index in the padding will be out of the bounds
3082   // of the array.
3083   const auto calculate_input_index = [this](llvm::Value* output_index,
3084                                             llvm::Value* kernel_index,
3085                                             const WindowDimension& window_dim) {
3086     llvm::Value* strided_index =
3087         NSWMul(output_index, b_->getInt64(window_dim.stride()));
3088     llvm::Value* dilated_kernel_index =
3089         NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
3090     return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
3091                   b_->getInt64(window_dim.padding_low()));
3092   };
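  // Illustrative example: with stride 2, window_dilation 1, and padding_low 1,
  // output index 3 and kernel index 0 map to input index 3*2 + 0*1 - 1 = 5.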
3093   std::vector<llvm::Value*> input_spatial(num_spatial_dims);
3094   for (int i = 0; i < num_spatial_dims; ++i) {
3095     input_spatial[i] = calculate_input_index(
3096         output_spatial[i], kernel_spatial[i], window.dimensions(i));
3097   }
3098 
3099   // We need to check that 0 <= input dim < bound; otherwise we are in the
3100   // padding and can skip the computation. That is equivalent to input
3101   // dim < bound as an *unsigned* comparison, since a negative value wraps
3102   // to a large positive value. The input dim is dilated, so we need to dilate
3103   // the bound as well to match.
3104 
3105   // Also need to check that the input coordinates are not in one of the
3106   // holes created by base dilation.
3107   const auto not_in_hole = [&](llvm::Value* input_index,
3108                                int64_t base_dilation) {
3109     llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
3110     return ICmpEQ(remainder, b_->getInt64(0));
3111   };
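  // Illustrative example: with base_dilation 2, dilated coordinates 0, 2, 4
  // correspond to lhs indices 0, 1, 2, while odd coordinates fall in holes.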
3112 
3113   llvm::Value* in_bounds_condition = b_->getInt1(true);
3114   for (int i = 0; i < num_spatial_dims; ++i) {
3115     llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
3116         lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
3117         window.dimensions(i).base_dilation()));
3118     llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
3119     llvm::Value* dim_not_in_hole =
3120         not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
3121     llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
3122     in_bounds_condition = And(in_bounds_condition, dim_ok);
3123   }
3124 
3125   // Now we need to map the dilated base coordinates back to the actual
3126   // data indices on the lhs.
3127   const auto undilate = [&](llvm::Value* input_index, int64_t base_dilation) {
3128     return SDiv(input_index, b_->getInt64(base_dilation));
3129   };
3130   for (int i = 0; i < num_spatial_dims; ++i) {
3131     input_spatial[i] =
3132         undilate(input_spatial[i], window.dimensions(i).base_dilation());
3133   }
3134 
3135   llvm_ir::LlvmIfData if_data =
3136       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
3137   SetToFirstInsertPoint(if_data.true_block, b_);
3138 
3139   // We are not in the padding, so carry out the computation.
3140   int num_dims = num_spatial_dims + 2;
3141   std::vector<llvm::Value*> input_multi_index(num_dims);
3142   for (int i = 0; i < num_spatial_dims; ++i) {
3143     input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
3144   }
3145   input_multi_index[dnums.input_feature_dimension()] = lhs_input_feature;
3146   input_multi_index[dnums.input_batch_dimension()] = batch;
3147 
3148   std::vector<llvm::Value*> kernel_multi_index(num_dims);
3149   for (int i = 0; i < num_spatial_dims; ++i) {
3150     kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
3151         window.dimensions(i).window_reversal()
3152             ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
3153                      kernel_spatial[i])
3154             : kernel_spatial[i];
3155   }
3156 
3157   kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
3158   kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
3159 
3160   llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
3161                                       b_->getInt64Ty());
3162   TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
3163                       input_generator(input_index));
3164   llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
3165                                        b_->getInt64Ty());
3166   TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
3167                       kernel_generator(kernel_index));
3168   llvm::Value* sum =
3169       EmitMulAdd(input_value, kernel_value,
3170                  Load(sum_address->getAllocatedType(), sum_address),
3171                  convolution->shape().element_type());
3172   Store(sum, sum_address);
3173 
3174   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
3175   return FPCast(Load(sum_address->getAllocatedType(), sum_address),
3176                 lhs_llvm_type);
3177 }
3178 
3179 // Evaluate polynomial using Horner's method.
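// Coefficients are ordered from highest to lowest degree; e.g. (illustrative)
// coefficients {2.0, -3.0, 1.0} evaluate as ((0*x + 2)*x - 3)*x + 1,
// i.e. 2*x^2 - 3*x + 1.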
3180 StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
3181     llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
3182   llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
3183   for (const double c : coefficients) {
3184     poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
3185   }
3186   return poly;
3187 }
3188 
3189 }  // namespace xla
3190