/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"

#include <memory>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {

ParallelLoopEmitter::ParallelLoopEmitter(
    llvm_ir::BodyEmitter body_emitter, const Shape& shape,
    const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
    LaunchDimensionsConfig launch_config)
    : launch_dimensions_(launch_dimensions),
      launch_config_(launch_config),
      body_emitter_(body_emitter),
      shape_(shape),
      b_(b) {}

ParallelLoopEmitter::ParallelLoopEmitter(
    const llvm_ir::ElementGenerator& target_element_generator,
    absl::Span<const llvm_ir::IrArray> target_arrays,
    const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
    LaunchDimensionsConfig launch_config)
    : launch_dimensions_(launch_dimensions),
      launch_config_(launch_config),
      body_emitter_(
          llvm_ir::MakeBodyEmitter(target_element_generator, target_arrays, b,
                                   /*is_tuple=*/target_arrays.size() > 1)),
      shape_(target_arrays[0].GetShape()),
      b_(b) {}

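// Computes the linear element index at which the current thread starts,
// combining the block id and thread id (and folding in the unroll factor
// where needed), and returns it together with threadIdx.x.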
ParallelLoopEmitter::LinearBaseAndThreadIdx
ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type,
                                                llvm::Value* base_index) {
  llvm::Value* block_id =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x,
                            static_cast<llvm::Instruction*>(block_id));
  block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id");

  // Per the PTX documentation:
  //   "It is guaranteed that [...] 0  <=  %tid.x <  %ntid.x"
  llvm::Value* thread_id_x =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x,
                            static_cast<llvm::Instruction*>(thread_id_x));
  thread_id_x = b_->CreateZExtOrTrunc(thread_id_x, index_type, "thread_id_x");

  const int unroll_factor =
      launch_config_.unroll_factor > 1 ? launch_config_.unroll_factor : 1;

  // Linear base is different for logical order vs physical order stores.
  // For logical,  linear_base = block_id*num_threads*unroll + thread_idx
  // For physical, linear_base = (block_id*num_threads + thread_idx)*unroll
  int block_id_multiplier =
      launch_config_.logical_order
          ? launch_dimensions_.total_nb_threads() * unroll_factor
          : launch_dimensions_.total_nb_threads();

  llvm::Value* linear_index_base = b_->CreateMul(
      block_id, llvm::ConstantInt::get(index_type, block_id_multiplier), "",
      /*HasNUW=*/true, /*HasNSW=*/true);

  linear_index_base =
      b_->CreateAdd(linear_index_base, thread_id_x, "linear_index",
                    /*HasNUW=*/true, /*HasNSW=*/true);

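  // When blockDim.y > 1, also fold threadIdx.y into the linear index: each
  // y-slice of the block accounts for blockDim.x elements.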
  if (launch_dimensions_.thread_counts_per_block().y > 1) {
    CHECK(!launch_config_.logical_order);
    llvm::Value* thread_id_y =
        EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdy, {}, {}, b_);
    llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().y,
                              static_cast<llvm::Instruction*>(thread_id_y));
    thread_id_y = b_->CreateZExtOrTrunc(thread_id_y, index_type, "thread_id_y");
    linear_index_base = b_->CreateAdd(
        linear_index_base,
        b_->CreateMul(
            thread_id_y,
            llvm::ConstantInt::get(
                index_type, launch_dimensions_.thread_counts_per_block().x),
            "",
            /*HasNUW=*/true, /*HasNSW=*/true),
        "",
        /*HasNUW=*/true, /*HasNSW=*/true);
  }

  // Add an @llvm.assume(linear_index < threads_per_block * num_blocks).
  //
  // This might seem obvious from the computation above, but LLVM does not
  // currently determine the range of linear_index precisely.  InstCombine uses
  // known-bits, which, when applied to the task of determining a value's range,
  // is imprecise for everything other than powers of 2.  And
  // CorrelatedValuePropagation is, as a cost-saving measure, disabled for
  // conditions in the same basic block as their operands.
  llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::assume,
      {b_->CreateICmpULT(
          linear_index_base,
          llvm::ConstantInt::get(
              index_type,
              block_id_multiplier * launch_dimensions_.block_counts().x),
          "linear_index_in_range")},
      {}, b_);

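  // In the physical-order case the unroll factor was not folded into
  // block_id_multiplier above, so scale the linear base here; each thread then
  // owns unroll_factor consecutive elements.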
  if (!launch_config_.logical_order && launch_config_.unroll_factor > 1) {
    linear_index_base = b_->CreateMul(
        linear_index_base,
        llvm::ConstantInt::get(index_type, launch_config_.unroll_factor),
        "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
  }

  if (base_index != nullptr) {
    linear_index_base =
        b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base",
                      /*HasNUW=*/true, /*HasNSW=*/true);
  }
  return {linear_index_base, thread_id_x};
}

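// Logical-order variant: each index is obtained by delinearizing the linear
// index over the logical shape, and no bounds check is emitted because callers
// have already verified the bounds.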
std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitLogicalIndexAndSetExitBasicBlock(
    absl::string_view loop_name, llvm::Type* index_type,
    llvm::Value* base_index) {
  std::vector<llvm_ir::IrArray::Index> array_indices;

  LinearBaseAndThreadIdx base_and_threadidx =
      EmitLinearBaseAndThreadIdx(index_type, base_index);
  llvm::Value* linear_index_base = base_and_threadidx.linear_base;
  const int unroll_factor = launch_config_.unroll_factor;

  llvm::Value* linear_base = linear_index_base;

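  // Each unrolled element lies total_nb_threads() further along the linear
  // index; delinearize it with div/mod over the logical dimensions, starting
  // from the last (fastest-varying) dimension.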
  for (int i = 0; i < unroll_factor; ++i) {
    std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
    if (i > 0) {
      llvm::Value* addend = llvm::ConstantInt::get(
          index_type, launch_dimensions_.total_nb_threads());
      linear_base =
          b_->CreateAdd(linear_base, addend, absl::StrCat("linear_index", i),
                        /*HasNUW=*/true, /*HasNSW=*/true);
    }
    auto dims = shape_.dimensions();
    llvm::Value* last_val = linear_base;
    for (int j = dims.size() - 1; j >= 0; j--) {
      multidim[j] =
          b_->CreateURem(last_val, llvm::ConstantInt::get(index_type, dims[j]));
      last_val =
          b_->CreateUDiv(last_val, llvm::ConstantInt::get(index_type, dims[j]));
    }
    array_indices.emplace_back(multidim, shape_, index_type);
  }

  // We don't need to do bounds checking because this method is only
  // triggered for cases where we have already verified the bounds.
  llvm::BasicBlock* current_block = b_->GetInsertBlock();
  llvm::BasicBlock* body_block =
      llvm_ir::CreateBasicBlock(nullptr, "fusion-body", b_);
  if (current_block->getTerminator() == nullptr) {
    exit_bb_ = llvm_ir::CreateBasicBlock(nullptr, "after-fusion-body", b_);
  } else {
    exit_bb_ = current_block->splitBasicBlock(b_->GetInsertPoint(),
                                              "after-fusion-body");
    current_block->getTerminator()->eraseFromParent();
  }
  b_->SetInsertPoint(current_block);
  b_->CreateBr(body_block);
  b_->SetInsertPoint(body_block);
  b_->CreateBr(exit_bb_);

  // Set the IR builder insertion point to the start of the body block,
  // before its branch to exit_bb_.
  llvm_ir::SetToFirstInsertPoint(body_block, b_);
  return array_indices;
}

std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
                                                   llvm::Type* index_type,
                                                   llvm::Value* base_index) {
  // Emit the following code in LLVM IR:
  //   linear_index = blockIdx.x * blockDim.x * blockDim.y
  //                  [+ threadIdx.y * blockDim.x] + threadIdx.x;
  //   if (linear_index < num_elements) {
  //     array_index = LinearIndexToMultidimensionalIndex(shape_, linear_index);
  //     ...
  //   }
  // The part between [] is added only if blockDim.y > 1.
  // blockIdx.y and gridDim.y are always 1.

  // Per the PTX documentation:
  //   "It is guaranteed that [...] 0  <=  %ctaid.x <  %nctaid.x"
  //
  // %nctaid.x is currently specified as 2147483647.
  if (launch_dimensions_.thread_counts_per_block().y > 1) {
    // When blockDim.y > 1, we are in the small-row case: blockDim.x threads
    // cover exactly one row, and blockDim.y maps to consecutive rows. This
    // avoids block sizes that are too small to be efficient.
    CHECK(launch_config_.row_vectorized);
    CHECK_EQ(shape_.dimensions().back(),
             launch_dimensions_.thread_counts_per_block().x *
                 launch_config_.unroll_factor);
  }
  CHECK_EQ(launch_dimensions_.thread_counts_per_block().z, 1);
  CHECK_EQ(launch_dimensions_.block_counts().y, 1);
  CHECK_EQ(launch_dimensions_.block_counts().z, 1);
  VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor "
          << launch_config_.unroll_factor;
  CHECK_NE(index_type, nullptr);

  if (launch_config_.logical_order) {
    return EmitLogicalIndexAndSetExitBasicBlock(loop_name, index_type,
                                                base_index);
  }

  std::vector<llvm_ir::IrArray::Index> array_indices;
  LinearBaseAndThreadIdx linear_base_and_thread_idx =
      EmitLinearBaseAndThreadIdx(index_type, base_index);

  llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base;
  llvm::Value* thread_id_x = linear_base_and_thread_idx.thread_idx;

  // When row_vectorized is true, the innermost dimension matches the block
  // size, so we can generate simpler indexing for that dimension. This helps
  // LLVM generate vectorized code in those cases.
  llvm::Value* row_index = nullptr;
  if (!launch_config_.row_vectorized) {
    array_indices.emplace_back(linear_index_base, shape_, b_);
  } else {
    // Simpler index for row computation.
    // This will allow LLVM to vectorize.
    row_index = b_->CreateMul(
        thread_id_x,
        llvm::ConstantInt::get(index_type, launch_config_.unroll_factor),
        "row_index", /*HasNUW=*/true, /*HasNSW=*/true);
    std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
    multidim.back() = row_index;
    array_indices.emplace_back(linear_index_base, multidim, shape_, b_);
  }

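  // Emit indices for the remaining unrolled elements: element i lives at
  // linear_index_base + i (and, in the row-vectorized case, at row_index + i
  // within the row).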
  for (int i = 1; i < launch_config_.unroll_factor; ++i) {
    llvm::Value* linear_index =
        b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i),
                      absl::StrCat("linear_index", i),
                      /*HasNUW=*/true, /*HasNSW=*/true);
    if (!launch_config_.row_vectorized) {
      array_indices.emplace_back(linear_index, shape_, b_);
    } else {
      std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
      multidim.back() = b_->CreateAdd(
          row_index, llvm::ConstantInt::get(index_type, i),
          absl::StrCat("row_index_plus", i), /*HasNUW=*/true, /*HasNSW=*/true);
      array_indices.emplace_back(linear_index, multidim, shape_, b_);
    }
  }

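  // Guard the body with a bounds check: only threads whose base index lies
  // within the shape execute it; the rest fall through to the exit block.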
  auto if_in_bounds = llvm_ir::EmitIfThenElse(
      b_->CreateICmpULT(
          linear_index_base,
          llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))),
      llvm_ir::IrName(loop_name, "in_bounds"), b_, false);

  // Set exit_bb_ to the exit block of the if structure.
  exit_bb_ = if_in_bounds.after_block;
  CHECK_NE(nullptr, exit_bb_);

  // Set IR builder insertion point to the body of the if structure.
  llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_);

  return array_indices;
}

Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name,
                                           llvm::Type* index_type,
                                           llvm::Value* base_indvar) {
  for (const llvm_ir::IrArray::Index& array_index :
       EmitIndexAndSetExitBasicBlock(loop_name, index_type, base_indvar)) {
    TF_RETURN_IF_ERROR(body_emitter_(array_index));
  }
  return OkStatus();
}

Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name,
                                     llvm::Type* index_type) {
  if (index_type == nullptr) {
    index_type = b_->getInt64Ty();
  }
  int64_t total_threads = launch_dimensions_.launch_bound();
  int64_t num_elements = ShapeUtil::ElementsIn(shape_);
  // If all the elements are handled by the current threads, no need
  // to add a loop inside the kernel.
  if (total_threads * launch_config_.unroll_factor >= num_elements) {
    VLOG(1) << "No loops inside the kernel";
    TF_RETURN_IF_ERROR(EmitSerialLoop(loop_name, index_type));
  } else {
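    // Otherwise emit a serial loop inside the kernel: each iteration advances
    // the base index by total_threads * unroll_factor until every element is
    // covered.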
    KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll);
    auto constant = [&](int64_t val) {
      return llvm::ConstantInt::get(index_type, val);
    };

    TF_RETURN_IF_ERROR(ksl.ForWithStatus(
        "loop", constant(0), constant(num_elements),
        constant(total_threads * launch_config_.unroll_factor),
        [&](llvm::Value* base_indvar) {
          return EmitSerialLoop(loop_name, index_type, base_indvar);
        }));
  }

  // Set the insertion point of b_ to the loop exit, so that
  // code emitted for later instructions will be correctly placed.
  if (exit_bb_ != nullptr) {
    b_->SetInsertPoint(exit_bb_);
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla