1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 18 19 #include <vector> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/strings/string_view.h" 23 #include "absl/types/span.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 30 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" 31 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 34 namespace xla { 35 36 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { 37 public: 38 using HloToElementGeneratorMap = 39 absl::flat_hash_map<const HloInstruction*, llvm_ir::ElementGenerator>; 40 ElementalIrEmitter(llvm::Module * module,llvm::IRBuilder<> * b)41 ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) 42 : b_(b), module_(module) {} 43 44 virtual ~ElementalIrEmitter() = default; 45 46 // Returns a function to generate an element of the output of `hlo`, given a 47 // map of functions to generate elements of its operands. 48 llvm_ir::ElementGenerator MakeElementGenerator( 49 const HloInstruction* hlo, 50 const HloToElementGeneratorMap& operand_to_generator); 51 b()52 llvm::IRBuilder<>* b() { return b_; } 53 54 // builder() is for IrBuilderMixin. builder()55 llvm::IRBuilder<>* builder() { return b_; } 56 module()57 llvm::Module* module() { return module_; } 58 59 protected: GetSourceIndexOfBitcast(const llvm_ir::IrArray::Index & index,const HloInstruction * hlo)60 virtual llvm_ir::IrArray::Index GetSourceIndexOfBitcast( 61 const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) { 62 return index.SourceIndexOfBitcast(hlo->shape(), hlo->operand(0)->shape(), 63 b_); 64 } 65 66 virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, 67 llvm::Value* lhs_value, 68 llvm::Value* rhs_value); 69 70 virtual llvm::Value* EmitExtractReal(llvm::Value* value); 71 virtual llvm::Value* EmitExtractImag(llvm::Value* value); 72 73 private: 74 virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, 75 llvm::Value* operand_value); 76 77 virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, 78 llvm::Value* lhs_value, 79 llvm::Value* rhs_value); 80 81 virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op, 82 llvm::Value* operand_value); 83 84 virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op, 85 llvm::Value* operand_value); 86 87 virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op, 88 llvm::Value* operand_value); 89 90 llvm::Value* IsZero(llvm::Value* v); 91 llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); 92 llvm::Value* GetZero(llvm::Type* type); 93 llvm::Value* GetOne(llvm::Type* type); 94 llvm::Value* GetIntSMin(llvm::Type* type); 95 llvm::Value* GetMinusOne(llvm::Type* type); 96 97 llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, 98 bool is_signed); 99 llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, 100 bool is_signed); 101 llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs, 102 bool is_signed); 103 104 virtual StatusOr<llvm::Value*> EmitPredBinaryOp(const HloInstruction* op, 105 llvm::Value* lhs_value, 106 llvm::Value* rhs_value); 107 108 virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, 109 llvm::Value* lhs_value, 110 llvm::Value* rhs_value, 111 bool is_signed); 112 113 virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op, 114 llvm::Value* lhs_value, 115 llvm::Value* rhs_value); 116 117 virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, 118 llvm::Value* rhs_value, 119 absl::string_view name); 120 121 virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, 122 llvm::Value* rhs_value, 123 absl::string_view name); 124 125 llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, 126 bool is_signed); 127 128 llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, 129 bool is_signed); 130 131 virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, 132 llvm::Value* lhs, llvm::Value* rhs, 133 absl::string_view name); 134 135 virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, 136 llvm::Value* value); 137 138 virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type, 139 llvm::Value* value); 140 141 virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type, 142 llvm::Value* value); 143 144 virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type, 145 llvm::Value* value); 146 147 virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type, 148 llvm::Value* value); 149 150 virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, 151 llvm::Value* value); 152 153 virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, 154 llvm::Value* value); 155 156 virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, 157 llvm::Value* value, 158 absl::string_view name); 159 160 virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type, 161 llvm::Value* value); 162 163 virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, 164 llvm::Value* lhs, llvm::Value* rhs, 165 absl::string_view name); 166 167 virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, 168 llvm::Value* value); 169 170 virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, 171 llvm::Value* x); 172 173 virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>> 174 EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value, 175 bool return_sqrt); 176 177 virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type, 178 llvm::Value* operand_value); 179 180 virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type, 181 llvm::Value* operand_value); 182 virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs( 183 PrimitiveType prim_type, llvm::Value* operand_value); 184 185 virtual StatusOr<llvm::Value*> EmitComplexAdd(const HloInstruction* op, 186 llvm::Value* lhs_value, 187 llvm::Value* rhs_value); 188 189 virtual StatusOr<llvm::Value*> EmitComplexSubtract(const HloInstruction* op, 190 llvm::Value* lhs_value, 191 llvm::Value* rhs_value); 192 193 virtual StatusOr<llvm::Value*> EmitComplexMultiply(const HloInstruction* op, 194 llvm::Value* lhs_value, 195 llvm::Value* rhs_value); 196 197 virtual StatusOr<llvm::Value*> EmitComplexDivide(const HloInstruction* op, 198 llvm::Value* lhs_value, 199 llvm::Value* rhs_value); 200 201 virtual StatusOr<llvm::Value*> EmitComplexLog(const HloInstruction* op, 202 llvm::Value* operand_value); 203 204 virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op, 205 PrimitiveType prim_type, 206 llvm::Value* operand_value); 207 208 virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op, 209 PrimitiveType prim_type, 210 llvm::Value* operand_value); 211 212 virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op, 213 PrimitiveType prim_type, 214 llvm::Value* operand_value); 215 216 StatusOr<llvm::Value*> EmitAccumResult( 217 absl::Span<llvm::Value* const> accumulator_addrs, 218 llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic); 219 220 // Composes a complex struct. imag may be nullptr for simple cast operations. 221 llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, 222 llvm::Value* imag); 223 224 // Emit `accumulator + lhs * rhs` for the given primitive type. 225 llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs, 226 llvm::Value* accumulator, 227 xla::PrimitiveType primitive_type); 228 229 // Identifier of the thread unique among all threads on the device EmitThreadId()230 virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } 231 232 StatusOr<llvm::Value*> EmitElementalSelect( 233 const HloInstruction* hlo, 234 const HloToElementGeneratorMap& operand_to_generator, 235 const llvm_ir::IrArray::Index& index); 236 237 StatusOr<llvm::Value*> EmitElementalClamp( 238 const HloInstruction* hlo, 239 const HloToElementGeneratorMap& operand_to_generator, 240 const llvm_ir::IrArray::Index& index); 241 242 StatusOr<llvm::Value*> EmitElementalConcatenate( 243 const HloInstruction* hlo, 244 const HloToElementGeneratorMap& operand_to_generator, 245 const llvm_ir::IrArray::Index& target_index); 246 247 StatusOr<llvm::Value*> EmitElementalDynamicSlice( 248 const HloInstruction* hlo, 249 const HloToElementGeneratorMap& operand_to_generator, 250 const llvm_ir::IrArray::Index& index); 251 252 StatusOr<llvm::Value*> EmitElementalGather( 253 const HloInstruction* hlo, 254 const HloToElementGeneratorMap& operand_to_generator, 255 const llvm_ir::IrArray::Index& index); 256 257 StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice( 258 const HloInstruction* hlo, 259 const HloToElementGeneratorMap& operand_to_generator, 260 const llvm_ir::IrArray::Index& index); 261 262 StatusOr<llvm::Value*> EmitElementalPad( 263 const HloInstruction* hlo, 264 const HloToElementGeneratorMap& operand_to_generator, 265 const llvm_ir::IrArray::Index& padded_index); 266 267 StatusOr<llvm::Value*> EmitElementalDot( 268 const HloInstruction* hlo, 269 const HloToElementGeneratorMap& operand_to_generator, 270 const llvm_ir::IrArray::Index& dot_result_index); 271 272 virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall( 273 const HloComputation& callee, absl::Span<llvm::Value* const> parameters, 274 absl::string_view name, bool is_reducer) = 0; 275 276 StatusOr<llvm::Value*> EmitElementalMap( 277 const HloMapInstruction* map_instr, 278 absl::Span<llvm::Value* const> elemental_operands); 279 280 StatusOr<llvm::Value*> EmitElementalReduceWindow( 281 const HloReduceWindowInstruction* reduce_window, 282 std::vector<llvm_ir::ElementGenerator> input_generators, 283 std::vector<llvm_ir::ElementGenerator> initial_value_generators, 284 const llvm_ir::IrArray::Index& index); 285 286 StatusOr<llvm::Value*> EmitElementalReduce( 287 const HloReduceInstruction* reduce, 288 std::vector<llvm_ir::ElementGenerator> input_generators, 289 std::vector<llvm_ir::ElementGenerator> initial_value_generators, 290 const llvm_ir::IrArray::Index& index); 291 292 virtual StatusOr<llvm::Value*> EmitConvolution( 293 const HloInstruction* hlo, 294 const HloToElementGeneratorMap& operand_to_generator, 295 const llvm_ir::IrArray::Index& index); 296 297 // Computes the complex power function, returns (a + i*b)^(c + i*d). 298 StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op, 299 llvm::Value* a, llvm::Value* b, 300 llvm::Value* c, llvm::Value* d); 301 302 // Evaluates a polynomial using Horner's method. 303 StatusOr<llvm::Value*> EvaluatePolynomial( 304 llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients); 305 306 virtual bool fast_min_max() = 0; 307 308 llvm::IRBuilder<>* const b_; 309 310 llvm::Module* module_; 311 }; 312 313 } // namespace xla 314 315 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 316