/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_

#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"

namespace xla {
namespace gpu {

// Emits a parallel loop for every element in the given array shape. The
// emitted loop is executed by multiple GPU threads in parallel; each thread
// instance of the loop iterates over part of the array, and together they
// cover the entire array.
class ParallelLoopEmitter {
 public:
  // Constructs an emitter whose loop body is produced by `body_emitter`, one
  // invocation per element of `shape`.
  //
  // `launch_dimensions` specifies the number of threads and blocks to
  // parallelize the loop on. `launch_config` specifies further details of
  // how to parallelize.
  ParallelLoopEmitter(llvm_ir::BodyEmitter body_emitter, const Shape& shape,
                      const LaunchDimensions& launch_dimensions,
                      llvm::IRBuilder<>* b,
                      LaunchDimensionsConfig launch_config = {});

  // Constructs a loop emitter for a loop that generates one element of each
  // of N arrays on each iteration.
  //
  // This is used in multi-output fusion. `target_element_generator` should
  // produce a struct with N elements, one for each of `target_arrays`.
  ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
                      absl::Span<const llvm_ir::IrArray> target_arrays,
                      const LaunchDimensions& launch_dimensions,
                      llvm::IRBuilder<>* b,
                      LaunchDimensionsConfig launch_config = {});

  // Not copyable: the emitter holds a mutable IR-building position (b_,
  // exit_bb_) that must not be aliased.
  ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
  ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;

  // Emits the per-thread index computation and the loop's exit basic block,
  // returning the element indices this thread instance should process.
  // `base_index` offsets the linear index; presumably nullable/zero when no
  // offset is needed — confirm against the .cc implementation.
  std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
      absl::string_view loop_name, llvm::Type* index_type,
      llvm::Value* base_index);

  // This is similar to EmitIndexAndSetExitBasicBlock except that we
  // change the mapping of threads to output buffer index such that adjacent
  // threads write to logically adjacent indices in the output buffer instead
  // of physically adjacent indices.
  std::vector<llvm_ir::IrArray::Index> EmitLogicalIndexAndSetExitBasicBlock(
      absl::string_view loop_name, llvm::Type* index_type,
      llvm::Value* base_index);

  // Emits the full parallel loop (index computation plus body) into the
  // builder. If `index_type` is nullptr, a default index type is used by the
  // implementation.
  Status EmitLoop(absl::string_view loop_name = "",
                  llvm::Type* index_type = nullptr);

 private:
  // Pair of values produced while lowering the thread-to-index mapping: the
  // linear base index for this thread and the raw thread index.
  struct LinearBaseAndThreadIdx {
    llvm::Value* linear_base;
    llvm::Value* thread_idx;
  };

  // Computes the linear base index and thread index for the current thread,
  // offset by `base_index`.
  LinearBaseAndThreadIdx EmitLinearBaseAndThreadIdx(llvm::Type* index_type,
                                                    llvm::Value* base_index);
  // Emits a serial (single-thread) inner loop; used by EmitLoop when one
  // thread handles several elements. `base_indvar`, when non-null, seeds the
  // induction variable — TODO confirm against the .cc implementation.
  Status EmitSerialLoop(absl::string_view loop_name, llvm::Type* index_type,
                        llvm::Value* base_indvar = nullptr);

  // The thread and block dimensions to parallelize the loop on.
  const LaunchDimensions launch_dimensions_;
  const LaunchDimensionsConfig launch_config_;

  // An IR emitter that generates the loop body.
  llvm_ir::BodyEmitter body_emitter_;

  // The shape that the emitted loop iterates through.
  Shape shape_;

  // Points to the exit block of the emitted loop. If the given shape is
  // scalar, no loops are emitted and exit_bb_ is nullptr in that case.
  llvm::BasicBlock* exit_bb_;

  // The LLVM IR builder all emission goes through; not owned.
  llvm::IRBuilder<>* b_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_