/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

// Describes the tiling used by a kernel.
//
// Used by reductions and the 021 transpose algorithm. Both algorithms operate
// over "logical" 3D views of the input arrays, so the tiling and
// number-of-threads information has only 3 dimensions.
//
// In the presence of virtual threadIdx/blockIdx scaling, all accessors are
// "logical", unless otherwise specified.
class TilingScheme {
 public:
  enum { DimZ = 0, DimY, DimX, DimTot };

  enum IndexingOrder {
    // Each thread reads consecutive elements.
    LinearIndexingX,
    // Each thread reads strided elements while keeping memory coalescing.
    StridedIndexingX,
  };

  TilingScheme(Vector3 dims_in_elems, Vector3 tile_sizes, Vector3 num_threads,
               IndexingOrder indexing_order, int vector_size,
               int scaling_factor)
      : dims_in_elems_(dims_in_elems),
        tile_sizes_(tile_sizes),
        num_threads_(num_threads),
        indexing_order_(indexing_order),
        vector_size_(vector_size),
        thread_id_virtual_scaling_(scaling_factor) {
    CHECK_EQ(tile_sizes[2] % vector_size_, 0);
  }

  static std::string IndexingOrderToString(IndexingOrder order) {
    switch (order) {
      case LinearIndexingX:
        return "linear";
      case StridedIndexingX:
        return "strided";
    }
  }

  std::string ToString() const {
    return absl::StrJoin(
        {absl::StrFormat("dims_in_elems = {%s}",
                         absl::StrJoin(dims_in_elems_, ", ")),
         absl::StrFormat("tile_sizes = {%s}", absl::StrJoin(tile_sizes_, ", ")),
         absl::StrFormat("num_threads = {%s}",
                         absl::StrJoin(num_threads_, ", ")),
         absl::StrFormat("indexing_order = %s",
                         IndexingOrderToString(indexing_order_)),
         absl::StrFormat("vector_size = %d", vector_size_),
         absl::StrFormat("thread_id_virtual_scaling = %d",
                         thread_id_virtual_scaling_)},
        ", ");
  }

  // Number of elements in each dimension (Z/Y/X respectively).
  absl::Span<const int64_t> GetDimsInElems() const { return dims_in_elems_; }

  Vector3 GetDimsInBlocks() const {
    return {GetDimInBlock(0), GetDimInBlock(1), GetDimInBlock(2)};
  }

  // Number of blocks required to "cover" the given dimension.
  int64_t GetDimInBlock(int d) const {
    return CeilOfRatio(dims_in_elems_[d], GetBlockTileSizeFor(d));
  }

  // Tile size for a given dimension per thread.
  //
  // Equals the number of iterations in the loop each tile will make.
  int64_t GetTileSizeFor(int d) const { return tile_sizes_.at(d); }

  // Tile size for a given dimension per entire thread block.
  int64_t GetBlockTileSizeFor(int d) const {
    return num_threads_.at(d) * tile_sizes_.at(d);
  }

  // Number of threads in the given dimension.
  int64_t GetNumThreadsFor(int d) const { return num_threads_.at(d); }

  // Number of logical threads per block.
  int64_t GetNumThreadsPerBlock() const {
    return GetNumThreadsFor(0) * GetNumThreadsFor(1) * GetNumThreadsFor(2);
  }

  // Number of logical blocks.
  int64_t GetNumberOfBlocks() const {
    return GetDimInBlock(0) * GetDimInBlock(1) * GetDimInBlock(2);
  }

  // Number of physical blocks launched (with scaling applied).
  int64_t GetNumberOfBlocksPhysical() const {
    return CeilOfRatio(GetNumberOfBlocks(), thread_id_virtual_scaling_);
  }

  // Number of physical threads per block launched (with scaling applied).
  int64_t GetNumThreadsPerBlockPhysical() const {
    return GetNumThreadsPerBlock() * thread_id_virtual_scaling_;
  }

  IndexingOrder GetIndexingOrder() const { return indexing_order_; }
  int GetVectorSize() const { return vector_size_; }

  // Scaling factor for transforming the physical thread id into a logical one.
  int GetThreadIdScalingFactor() const { return thread_id_virtual_scaling_; }

 private:
  // The number of elements in each dimension.
  const Vector3 dims_in_elems_;

  // The number of elements for each dimension of a tile.
  const Vector3 tile_sizes_;

  // Number of threads implicitly assigned to each dimension.
  const Vector3 num_threads_;

  const IndexingOrder indexing_order_;

  // Vector size for dimension X.
  const int vector_size_;

  // Scaling applied to transform the physical threadIdx into a logical one.
  const int64_t thread_id_virtual_scaling_ = 1;
};
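
// Illustrative usage sketch (not part of this header; the concrete numbers
// are assumptions chosen for the example): a tiling over a logical Z/Y/X view
// of shape {1, 4096, 4096}, with 8x32 threads in Y/X, 8 X-elements per
// thread, vector width 2, and no virtual thread-id scaling.
//
//   TilingScheme scheme(/*dims_in_elems=*/{1, 4096, 4096},
//                       /*tile_sizes=*/{1, 1, 8},
//                       /*num_threads=*/{1, 8, 32},
//                       TilingScheme::StridedIndexingX,
//                       /*vector_size=*/2,
//                       /*scaling_factor=*/1);
//
//   // The block-level X tile covers 32 threads * 8 elements = 256 elements,
//   // so X needs CeilOfRatio(4096, 256) = 16 blocks and Y needs 4096 / 8 =
//   // 512, for 1 * 512 * 16 = 8192 logical blocks of 1 * 8 * 32 = 256
//   // logical threads each.
//   scheme.GetNumberOfBlocks();      // 8192
//   scheme.GetNumThreadsPerBlock();  // 256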

class ReductionCodegenInfo {
 public:
  explicit ReductionCodegenInfo(TilingScheme mapping_scheme,
                                int num_partial_results, bool is_row_reduction,
                                bool is_race_free)
      : tiling_scheme_(mapping_scheme),
        num_partial_results_(num_partial_results),
        is_row_reduction_(is_row_reduction),
        is_race_free_(is_race_free) {
    if (num_partial_results > 1) {
      CHECK_EQ(num_partial_results,
               mapping_scheme.GetTileSizeFor(TilingScheme::DimX));
    }
  }

  const TilingScheme& GetTilingScheme() const { return tiling_scheme_; }

  int GetNumPartialResults() const { return num_partial_results_; }
  bool IsRaceFree() const { return is_race_free_; }

 private:
  friend class ReductionCodegenState;

  const TilingScheme tiling_scheme_;
  int num_partial_results_;
  bool is_row_reduction_;
  bool is_race_free_;
};
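
// Illustrative usage sketch (assumed emitter flow, not mandated by this
// header): codegen wraps a ReductionCodegenInfo in a ReductionCodegenState,
// records one ReductionCalculationState per reduction input in order (the
// CHECK_EQ in SetCalculationStateFor enforces in-order insertion), and reads
// the state back when emitting the final reduction. `reduce`, `num_inputs`,
// and the LLVM values below are hypothetical placeholders.
//
//   ReductionCodegenState state(reduction_info);
//   for (int i = 0; i < num_inputs; ++i) {
//     state.SetCalculationStateFor(
//         {/*shared_cache=*/cache[i], /*initial_value=*/init[i],
//          /*partial_result_address=*/partial[i], /*input_address=*/addr[i],
//          /*input_gen=*/gen[i]},
//         reduce, /*operand_idx=*/i);
//   }
//   const auto& calc = state.GetCalculationStateFor(reduce, /*operand_idx=*/0);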

class ReductionCodegenState {
 public:
  struct ReductionCalculationState {
    llvm::GlobalVariable* shared_cache;
    llvm::Value* initial_value;
    llvm::AllocaInst* partial_result_address;
    llvm::AllocaInst* input_address;
    llvm_ir::ElementGenerator input_gen;
  };

  explicit ReductionCodegenState(
      const ReductionCodegenInfo& reduction_codegen_info)
      : reduction_codegen_info_(reduction_codegen_info) {}

  const TilingScheme& GetTilingScheme() const {
    return reduction_codegen_info_.tiling_scheme_;
  }

  int GetNumPartialResults() const {
    return reduction_codegen_info_.num_partial_results_;
  }

  bool IsRowReduction() const {
    return reduction_codegen_info_.is_row_reduction_;
  }

  bool IsRaceFree() const { return reduction_codegen_info_.IsRaceFree(); }

  const ReductionCalculationState& GetCalculationStateFor(
      const HloInstruction* instruction, int operand_idx) const {
    const ReductionOpState& op_state = state_.at(instruction);
    CHECK_LT(operand_idx, op_state.size());
    return op_state[operand_idx];
  }

  void SetCalculationStateFor(
      const ReductionCalculationState& calculation_state,
      const HloInstruction* instruction, int operand_idx) {
    ReductionOpState& op_state = state_[instruction];
    CHECK_EQ(operand_idx, op_state.size());
    op_state.push_back(calculation_state);
  }

 private:
  ReductionCodegenInfo reduction_codegen_info_;

  // One state per reduction operand.
  using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;

  // HloInstruction -> operand_idx -> cache
  absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
};

}  // end namespace gpu
}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_