xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_
18 
19 #include "absl/container/inlined_vector.h"
20 #include "absl/types/span.h"
21 #include "llvm/IR/Value.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
24 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 // Describes tiling used by the kernel.
31 //
32 // Used by reductions and 021 transpose algorithm. Both algorithms operate over
33 // "logical" 3D views over input arrays, hence tiling and number of threads
34 // information has only 3 dimensions.
35 //
36 // In the presence of virtual threadIdx/blockIdx scaling, all accessors are
37 // "logical", unless otherwise specified.
38 class TilingScheme {
39  public:
40   enum { DimZ = 0, DimY, DimX, DimTot };
41 
42   enum IndexingOrder {
43     // Thread reads consecutive elements.
44     LinearIndexingX,
45     // Thread reads strided elements while keeping memory coalescing.
46     StridedIndexingX,
47   };
48 
TilingScheme(Vector3 dims_in_elems,Vector3 tile_sizes,Vector3 num_threads,IndexingOrder indexing_order,int vector_size,int scaling_factor)49   TilingScheme(Vector3 dims_in_elems, Vector3 tile_sizes, Vector3 num_threads,
50                IndexingOrder indexing_order, int vector_size,
51                int scaling_factor)
52       : dims_in_elems_(dims_in_elems),
53         tile_sizes_(tile_sizes),
54         num_threads_(num_threads),
55         indexing_order_(indexing_order),
56         vector_size_(vector_size),
57         thread_id_virtual_scaling_(scaling_factor) {
58     CHECK_EQ(tile_sizes[2] % vector_size_, 0);
59   }
60 
IndexingOrderToString(IndexingOrder order)61   static std::string IndexingOrderToString(IndexingOrder order) {
62     switch (order) {
63       case LinearIndexingX:
64         return "linear";
65       case StridedIndexingX:
66         return "strided";
67     }
68   }
69 
ToString()70   std::string ToString() const {
71     return absl::StrJoin(
72         {absl::StrFormat("dims_in_elems = {%s}",
73                          absl::StrJoin(dims_in_elems_, ", ")),
74          absl::StrFormat("tile_sizes = {%s}", absl::StrJoin(tile_sizes_, ", ")),
75          absl::StrFormat("num_threads = {%s}",
76                          absl::StrJoin(num_threads_, ", ")),
77          absl::StrFormat("indexing_order = %s",
78                          IndexingOrderToString(indexing_order_)),
79          absl::StrFormat("vector_size = %d", vector_size_),
80          absl::StrFormat("thread_id_virtual_scaling = %d",
81                          thread_id_virtual_scaling_)},
82         ", ");
83   }
84 
85   // Number of elements in each dimension (Z/Y/X respectively).
GetDimsInElems()86   absl::Span<const int64_t> GetDimsInElems() const { return dims_in_elems_; }
87 
GetDimsInBlocks()88   Vector3 GetDimsInBlocks() const {
89     return {GetDimInBlock(0), GetDimInBlock(1), GetDimInBlock(2)};
90   }
91 
92   // Number of blocks required to "cover" the given dimension.
GetDimInBlock(int d)93   int64_t GetDimInBlock(int d) const {
94     return CeilOfRatio(dims_in_elems_[d], GetBlockTileSizeFor(d));
95   }
96 
97   // Tile size for a given dimensions per thread.
98   //
99   // Equals to the number of iterations in the loop each tile will make.
GetTileSizeFor(int d)100   int64_t GetTileSizeFor(int d) const { return tile_sizes_.at(d); }
101 
102   // Tile size for a given dimension per entire thread block.
GetBlockTileSizeFor(int d)103   int64_t GetBlockTileSizeFor(int d) const {
104     return num_threads_.at(d) * tile_sizes_.at(d);
105   }
106 
107   // Number of threads in given dimension.
GetNumThreadsFor(int d)108   int64_t GetNumThreadsFor(int d) const { return num_threads_.at(d); }
109 
110   // Number of logical threads per block.
GetNumThreadsPerBlock()111   int64_t GetNumThreadsPerBlock() const {
112     return GetNumThreadsFor(0) * GetNumThreadsFor(1) * GetNumThreadsFor(2);
113   }
114 
115   // Number of logical blocks.
GetNumberOfBlocks()116   int64_t GetNumberOfBlocks() const {
117     return GetDimInBlock(0) * GetDimInBlock(1) * GetDimInBlock(2);
118   }
119 
120   // Number of physical blocks launched (with scaling applied).
GetNumberOfBlocksPhysical()121   int64_t GetNumberOfBlocksPhysical() const {
122     return CeilOfRatio(GetNumberOfBlocks(), thread_id_virtual_scaling_);
123   }
124 
125   // Number of physical threads per block launched (with scaling applied).
GetNumThreadsPerBlockPhysical()126   int64_t GetNumThreadsPerBlockPhysical() const {
127     return GetNumThreadsPerBlock() * thread_id_virtual_scaling_;
128   }
129 
GetIndexingOrder()130   IndexingOrder GetIndexingOrder() const { return indexing_order_; }
GetVectorSize()131   int GetVectorSize() const { return vector_size_; }
132 
133   // Scaling factor for transforming physical threadId to logical.
GetThreadIdScalingFactor()134   int GetThreadIdScalingFactor() const { return thread_id_virtual_scaling_; }
135 
136  private:
137   // The number of elements in each dimension.
138   const Vector3 dims_in_elems_;
139 
140   // The number of elements for each dimension of a tile.
141   const Vector3 tile_sizes_;
142 
143   // Number of threads implicitly assigned to each dimension.
144   const Vector3 num_threads_;
145 
146   const IndexingOrder indexing_order_;
147 
148   // Vector size for dimension X.
149   const int vector_size_;
150 
151   // Scaling apply to transform physical threadIdx into logical.
152   const int64_t thread_id_virtual_scaling_ = 1;
153 };
154 
155 class ReductionCodegenInfo {
156  public:
ReductionCodegenInfo(TilingScheme mapping_scheme,int num_partial_results,bool is_row_reduction,bool is_race_free)157   explicit ReductionCodegenInfo(TilingScheme mapping_scheme,
158                                 int num_partial_results, bool is_row_reduction,
159                                 bool is_race_free)
160       : tiling_scheme_(mapping_scheme),
161         num_partial_results_(num_partial_results),
162         is_row_reduction_(is_row_reduction),
163         is_race_free_(is_race_free) {
164     if (num_partial_results > 1) {
165       CHECK_EQ(num_partial_results,
166                mapping_scheme.GetTileSizeFor(TilingScheme::DimX));
167     }
168   }
169 
GetTilingScheme()170   const TilingScheme& GetTilingScheme() const { return tiling_scheme_; }
171 
GetNumPartialResults()172   int GetNumPartialResults() const { return num_partial_results_; }
IsRaceFree()173   bool IsRaceFree() const { return is_race_free_; }
174 
175  private:
176   friend class ReductionCodegenState;
177 
178   const TilingScheme tiling_scheme_;
179   int num_partial_results_;
180   bool is_row_reduction_;
181   bool is_race_free_;
182 };
183 
184 class ReductionCodegenState {
185  public:
186   struct ReductionCalculationState {
187     llvm::GlobalVariable* shared_cache;
188     llvm::Value* initial_value;
189     llvm::AllocaInst* partial_result_address;
190     llvm::AllocaInst* input_address;
191     llvm_ir::ElementGenerator input_gen;
192   };
193 
ReductionCodegenState(const ReductionCodegenInfo & reduction_codegen_info)194   explicit ReductionCodegenState(
195       const ReductionCodegenInfo& reduction_codegen_info)
196       : reduction_codegen_info_(reduction_codegen_info) {}
197 
GetTilingScheme()198   const TilingScheme& GetTilingScheme() const {
199     return reduction_codegen_info_.tiling_scheme_;
200   }
201 
GetNumPartialResults()202   int GetNumPartialResults() const {
203     return reduction_codegen_info_.num_partial_results_;
204   }
205 
IsRowReduction()206   bool IsRowReduction() const {
207     return reduction_codegen_info_.is_row_reduction_;
208   }
209 
IsRaceFree()210   bool IsRaceFree() const { return reduction_codegen_info_.IsRaceFree(); }
211 
GetCalculationStateFor(const HloInstruction * instruction,int operand_idx)212   const ReductionCalculationState& GetCalculationStateFor(
213       const HloInstruction* instruction, int operand_idx) const {
214     const ReductionOpState& op_state = state_.at(instruction);
215     CHECK_LT(operand_idx, op_state.size());
216     return op_state[operand_idx];
217   }
218 
SetCalculationStateFor(const ReductionCalculationState & calculation_state,const HloInstruction * instruction,int operand_idx)219   void SetCalculationStateFor(
220       const ReductionCalculationState& calculation_state,
221       const HloInstruction* instruction, int operand_idx) {
222     ReductionOpState& op_state = state_[instruction];
223     CHECK_EQ(operand_idx, op_state.size());
224     op_state.push_back(calculation_state);
225   }
226 
227  private:
228   ReductionCodegenInfo reduction_codegen_info_;
229 
230   // One state per reduction operand.
231   using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;
232 
233   // HloInstruction -> operand_idx -> cache
234   absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
235 };
236 
237 }  // end namespace gpu
238 }  // end namespace xla
239 
240 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_
241