/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>

#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM
// IR functions.
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
  friend class CpuElementalIrEmitter;

 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used
  //   by the HLO nodes.
  // mlir_context: the MLIR context used for IR emission.
  // llvm_module: the LLVM module to emit IR into. It's built using the LLVM
  //   context inside of mlir_context.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //   index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //   index in the profiling array.
  // computation_transitively_contains_custom_call: the mapping from HLO
  //   computations to whether or not they transitively contain a custom-call
  //   instruction. All computations in the module must have a key in this
  //   map.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
            const BufferAssignment& assignment, llvm::Module* llvm_module,
            absl::flat_hash_map<const HloInstruction*, int64_t>
                instruction_to_profile_idx,
            absl::flat_hash_map<const HloComputation*, int64_t>
                computation_to_profile_idx,
            absl::flat_hash_map<const HloComputation*, bool>
                computation_transitively_contains_custom_call,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU
  // backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the
  //    entry computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order. In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  //
  // If 'allow_reassociation' is true, the fast-math reassociation flag will
  // be enabled in the function's body. This is used when emitting reducers.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const std::string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order,
      bool allow_reassociation);

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();
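  // Illustrative usage sketch (simplified, not the exact compiler-driver call
  // sites; `assignment`, `schedule`, and `target_machine_features` are
  // assumed to be set up by the caller):
  //
  //   IrEmitter ir_emitter(mlir_context, *hlo_module, *assignment, llvm_module,
  //                        std::move(instruction_to_profile_idx),
  //                        std::move(computation_to_profile_idx),
  //                        std::move(computation_transitively_contains_custom_call),
  //                        &target_machine_features,
  //                        /*emit_code_for_msan=*/false);
  //   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
  //   HloComputation* entry = hlo_module->entry_computation();
  //   TF_ASSIGN_OR_RETURN(
  //       llvm::Function * entry_function,
  //       ir_emitter.EmitComputation(
  //           entry, entry->name(), /*is_top_level_computation=*/true,
  //           schedule.sequence(entry).instructions(),
  //           /*allow_reassociation=*/false));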
 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleCollectivePermute(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandlePartitionId(HloInstruction* hlo) override;
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* rng) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).value();
  }
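  // For instance (illustrative only), a handler for an instruction with a
  // tuple-shaped result could look up the slice backing tuple element {0} as:
  //
  //   BufferAllocation::Slice slice = GetAllocationSlice(*hlo, /*index=*/{0});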
 private:
  Status HandleSliceToDynamic(HloInstruction* hlo);
  Status HandlePadToStatic(HloInstruction* hlo);
  Status HandleTopK(HloInstruction* hlo);
  Status HandleAllReduceSingleReplica(HloInstruction* crs);
  Status HandleAllReduceMultipleReplica(HloInstruction* crs);

  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const std::string& function_name);

  // Emits the copying epilogue for the function, where it copies the returned
  // value to the reserved alloca. This is only necessary for thread-local
  // functions. Note that since the call graph is flattened, if the same
  // function is called in both thread-local and non-thread-local contexts, it
  // is codegen'd twice, so we know at codegen time whether a function is
  // thread-local.
  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);

  // Convenience functions to generate a GEP into the profile counter
  // parameter which would correspond to the index for a given HLO instruction
  // or computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction);
  llvm::Value* GetProfileCounterFor(const HloComputation& computation);

  // Helper function template for the implementation of the above two
  // functions.
  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const absl::flat_hash_map<const T*, int64_t>& profile_index_map);

  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates the buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted --
  // if not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the llvm::Value* that represents the "status" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetStatusArgument();

  // Get the xla::ExecutableRunOptions that represents the "run_options"
  // argument of the computation function being emitted by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Get the llvm::BasicBlock that contains the return instruction.
  llvm::BasicBlock* GetReturnBlock();

  // Emits code to check the state of the status object being threaded through
  // each computation and return early if it's in an error state.
  void EmitEarlyReturnIfErrorStatus();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation
  // slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);

  // Emits a call to a thread-local function (e.g. to the computation nested
  // within a reduce or a map). Thread-local callees (by definition) only
  // write to and read from thread-local allocations. Supports only functions
  // returning scalars or tuples of scalars.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee. The return value is the scalar returned by the callee.
  std::vector<llvm::Value*> EmitThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name, bool is_reducer);

  // Similar to EmitThreadLocalCall, but assumes that the function returns a
  // scalar.
  llvm::Value* EmitScalarReturningThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);

  // Emits a call to a "global" function (e.g. to the computation nested
  // within a kWhile or a kCall). Buffer assignment unambiguously assigns
  // buffers to the parameters and return values for these computations, so
  // there is no need to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);

  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target
  // op. This produces a series of nested loops (one for each dimension of the
  // op's shape). The body of the inner-most loop is provided by the
  // element_generator function.
  //
  // desc is an optional human-readable string that's added to the loop name
  // in IR. Regardless of whether desc is provided, target_op->name() is
  // included in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);

  // Emits a memcpy from the source instruction's result value to the
  // destination's. Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayFor or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);
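  // Illustrative sketch of the pattern most handlers follow (hypothetical
  // handler body; the element generator shown is a placeholder):
  //
  //   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
  //   return EmitTargetElementLoop(
  //       hlo, [&](const llvm_ir::IrArray::Index& index)
  //                -> StatusOr<llvm::Value*> {
  //         // ... compute and return the value of hlo's element at `index` ...
  //       });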
  // Structurizes "array_elements" into an MD array that represents "shape".
  // This is a recursive function, and "dimension_index" indicates the index
  // of the current dimension that the function is considering (0 means the
  // most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64_t dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure. On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations. We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64_t> dimensions,
                                      HloComputation* function,
                                      std::string* failure_reason);

  // We'd like to keep one or two cache lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types. We do this by introducing a layer of
  // abstraction: we introduce a high-level vector-like concept called a
  // "sharded vector" that models data parallelism, and is mapped to a
  // sequence of scalar and vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped
  // to a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>,
  // f32]. Note that the last element is scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the elements
  // mapped to sharded vectors -- we allow repeated elements, and we allow
  // elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the element-wise llvm::Types of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Create a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);

  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              llvm::Align alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern. Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said
  // reduction. On failure, this stores a reason string into "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             std::string* failure_reason) const;
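  // For example (illustrative only), a reducer computation that adds its two
  // parameters could match to a generator equivalent to:
  //
  //   ReductionGenerator add_gen = [](llvm::IRBuilder<>* b, llvm::Value* lhs,
  //                                   llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   };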
  // Emits the inner loop nest that runs the reduction. Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64_t> dimensions,
      llvm::Align element_alignment);

  // Tries to emit a fast concatenate operation using memcpy. Returns true if
  // successful, and false on failure. On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     std::string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type
  // "primitive_type" from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64_t element_count, PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Emits printing during the execution.
  llvm::Value* EmitPrintf(absl::string_view fmt,
                          absl::Span<llvm::Value* const> arguments);
  llvm::Value* EmitPrintfToStderr(absl::string_view fmt,
                                  absl::Span<llvm::Value* const> arguments);

  // Emits a call to a non-variadic function `func_name` with arguments
  // `arguments` assuming C calling convention.
  llvm::Value* EmitCallToFunc(
      std::string func_name, const std::vector<llvm::Value*>& arguments,
      llvm::Type* return_type, bool does_not_throw = true,
      bool only_accesses_arg_memory = false,
      bool only_accesses_inaccessible_mem_or_arg_mem = false);

  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  struct ComputationToEmit {
    const HloComputation* computation;
    bool allow_reassociation;

    bool operator==(const ComputationToEmit& other) const {
      return computation == other.computation &&
             allow_reassociation == other.allow_reassociation;
    }

    template <typename H>
    friend H AbslHashValue(H h, const ComputationToEmit& c) {
      return H::combine(std::move(h), c.computation, c.allow_reassociation);
    }
    friend std::ostream& operator<<(std::ostream& os,
                                    const ComputationToEmit& c) {
      return os << c.computation->name() << ", " << c.allow_reassociation;
    }
  };

  // Map containing all previously emitted computations.
  absl::flat_hash_map<ComputationToEmit, llvm::Function*> emitted_functions_;

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM
  // memory management rules, their memory is owned by the module (note that
  // IrFunction creates the encapsulated llvm::Function s.t. it is added to
  // the llvm module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;
  mlir::MLIRContext* mlir_context_;
  bool allow_reassociation_;

  // The buffer allocation slice for the root of the computation being
  // compiled. Only relevant for thread-local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers. Only relevant for thread-local
  // computations.
  absl::flat_hash_map<BufferAllocation::Index, int64_t>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const absl::flat_hash_map<const HloInstruction*, int64_t>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const absl::flat_hash_map<const HloComputation*, int64_t>
      computation_to_profile_idx_;

  // Maps HLO computations to whether they contain a custom-call instruction
  // (either directly, or transitively by e.g. calling another computation
  // that does).
  const absl::flat_hash_map<const HloComputation*, bool>
      computation_transitively_contains_custom_call_;

  // Accessor for the custom-call mapping that enforces the precondition that
  // all computations must have a key in the map.
  bool ComputationTransitivelyContainsCustomCall(
      const HloComputation* computation) const {
    auto it = computation_transitively_contains_custom_call_.find(computation);
    CHECK(it != computation_transitively_contains_custom_call_.cend())
        << "Must provide 'contains CustomCall' annotation for all computations "
           "in the module";
    return it->second;
  }

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of outer dimensions of the root instruction's shape that
  // will be partitioned when emitting parallel loops. (See
  // ParallelLoopEmitter).
  int64_t num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel
  // loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension loop
    // bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This class contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);
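    // Illustrative sketch (assuming the emitter wires these calls up around
    // each profiled HLO, e.g. from IrEmitter::Preprocess/Postprocess; the
    // exact call sites live in the .cc file):
    //
    //   profiling_state_.RecordCycleStart(b(), hlo);
    //   ... emit the code for *hlo ...
    //   profiling_state_.RecordCycleDelta(b(), hlo, GetProfileCounterFor(*hlo));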
    // Convenience function to generate a call to an intrinsic which reads the
    // CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end, llvm::Value* cycle_start);

   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the HLO
    // began to execute.
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };

  ProfilingState profiling_state_;

  class TracingState {
   public:
    TracingState() : enabled_(false) {}
    void set_enabled(bool value) { enabled_ = value; }
    void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* run_options);
    void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
                        llvm::Value* run_options);

   private:
    bool enabled_;
    // Maps from HLO to the activity id returned by xprof::TraceMe.
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> activity_ids_;
  };
  TracingState tracing_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                      int64_t buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer
  // size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64_t buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);

  // Calculate the alignment of a buffer allocated for a given primitive type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64_t ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Returns a ConstExpr bitcast.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const {
      return absl::HashOf(*literal);
    }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_