/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"

#include <memory>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {

ParallelLoopEmitter::ParallelLoopEmitter(
    llvm_ir::BodyEmitter body_emitter, const Shape& shape,
    const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
    LaunchDimensionsConfig launch_config)
    : launch_dimensions_(launch_dimensions),
      launch_config_(launch_config),
      body_emitter_(body_emitter),
      shape_(shape),
      b_(b) {}

ParallelLoopEmitter::ParallelLoopEmitter(
    const llvm_ir::ElementGenerator& target_element_generator,
    absl::Span<const llvm_ir::IrArray> target_arrays,
    const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
    LaunchDimensionsConfig launch_config)
    : launch_dimensions_(launch_dimensions),
      launch_config_(launch_config),
      body_emitter_(
          llvm_ir::MakeBodyEmitter(target_element_generator, target_arrays, b,
                                   /*is_tuple=*/target_arrays.size() > 1)),
      shape_(target_arrays[0].GetShape()),
      b_(b) {}
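
// Illustrative usage (a minimal sketch, not taken from this file): a caller
// such as a fusion emitter typically constructs a ParallelLoopEmitter from an
// element generator and the output arrays, then emits the loop nest for a
// given launch configuration. The names `element_generator`, `output_arrays`,
// `launch_dimensions`, `launch_config`, `hlo_name` and `b` below are assumed
// to come from the surrounding kernel-emission code.
//
//   ParallelLoopEmitter emitter(element_generator, output_arrays,
//                               launch_dimensions, &b, launch_config);
//   TF_RETURN_IF_ERROR(
//       emitter.EmitLoop(llvm_ir::IrName(hlo_name), index_type));
//
// EmitLoop() leaves the IRBuilder positioned at the loop's exit block, so the
// caller can continue emitting code that runs after every element is written.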

ParallelLoopEmitter::LinearBaseAndThreadIdx
ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type,
                                                llvm::Value* base_index) {
  llvm::Value* block_id =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x,
                            static_cast<llvm::Instruction*>(block_id));
  block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id");

  // Per the PTX documentation:
  //   "It is guaranteed that [...] 0 <= %tid.x < %ntid.x"
  llvm::Value* thread_id_x =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x,
                            static_cast<llvm::Instruction*>(thread_id_x));
  thread_id_x = b_->CreateZExtOrTrunc(thread_id_x, index_type, "thread_id_x");

  const int unroll_factor =
      launch_config_.unroll_factor > 1 ? launch_config_.unroll_factor : 1;

  // Linear base is different for logical order vs physical order stores.
  // For logical, linear_base = block_id*num_threads*unroll + thread_idx
  // For physical, linear_base = (block_id*num_threads + thread_idx)*unroll
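  //
  // As an illustration (numbers chosen here for exposition only): with 128
  // threads per block, unroll_factor 4, block_id 2 and thread_idx 5,
  //   logical order:  linear_base = 2 * 128 * 4 + 5   = 1029
  //   physical order: linear_base = (2 * 128 + 5) * 4 = 1044
  // (for physical order, the multiplication by the unroll factor happens
  // further below, after the optional thread_id_y term is added).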
  int block_id_multiplier =
      launch_config_.logical_order
          ? launch_dimensions_.total_nb_threads() * unroll_factor
          : launch_dimensions_.total_nb_threads();

  llvm::Value* linear_index_base = b_->CreateMul(
      block_id, llvm::ConstantInt::get(index_type, block_id_multiplier), "",
      /*HasNUW=*/true, /*HasNSW=*/true);

  linear_index_base =
      b_->CreateAdd(linear_index_base, thread_id_x, "linear_index",
                    /*HasNUW=*/true, /*HasNSW=*/true);

  if (launch_dimensions_.thread_counts_per_block().y > 1) {
    CHECK(!launch_config_.logical_order);
    llvm::Value* thread_id_y =
        EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdy, {}, {}, b_);
    llvm_ir::AddRangeMetadata(
        0, launch_dimensions_.thread_counts_per_block().y,
        static_cast<llvm::Instruction*>(thread_id_y));
    thread_id_y =
        b_->CreateZExtOrTrunc(thread_id_y, index_type, "thread_id_y");
    linear_index_base = b_->CreateAdd(
        linear_index_base,
        b_->CreateMul(
            thread_id_y,
            llvm::ConstantInt::get(
                index_type, launch_dimensions_.thread_counts_per_block().x),
            "",
            /*HasNUW=*/true, /*HasNSW=*/true),
        "",
        /*HasNUW=*/true, /*HasNSW=*/true);
  }

  // Add an @llvm.assume(linear_index < threads_per_block * num_blocks).
  //
  // This might seem obvious from the computation above, but LLVM does not
  // currently determine the range of linear_index precisely. InstCombine uses
  // known-bits, which, when applied to the task of determining a value's
  // range, is imprecise for everything other than powers of 2. And
  // CorrelatedValuePropagation is, as a cost-saving measure, disabled for
  // conditions in the same basic block as their operands.
  llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::assume,
      {b_->CreateICmpULT(
          linear_index_base,
          llvm::ConstantInt::get(
              index_type,
              block_id_multiplier * launch_dimensions_.block_counts().x),
          "linear_index_in_range")},
      {}, b_);

  if (!launch_config_.logical_order && launch_config_.unroll_factor > 1) {
    linear_index_base = b_->CreateMul(
        linear_index_base,
        llvm::ConstantInt::get(index_type, launch_config_.unroll_factor),
        "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
  }

  if (base_index != nullptr) {
    linear_index_base =
        b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base",
                      /*HasNUW=*/true, /*HasNSW=*/true);
  }
  return {linear_index_base, thread_id_x};
}

std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitLogicalIndexAndSetExitBasicBlock(
    absl::string_view loop_name, llvm::Type* index_type,
    llvm::Value* base_index) {
  std::vector<llvm_ir::IrArray::Index> array_indices;

  LinearBaseAndThreadIdx base_and_threadidx =
      EmitLinearBaseAndThreadIdx(index_type, base_index);
  llvm::Value* linear_index_base = base_and_threadidx.linear_base;
  const int unroll_factor = launch_config_.unroll_factor;

  llvm::Value* linear_base = linear_index_base;

  for (int i = 0; i < unroll_factor; ++i) {
    std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
    if (i > 0) {
      llvm::Value* addend = llvm::ConstantInt::get(
          index_type, launch_dimensions_.total_nb_threads());
      linear_base =
          b_->CreateAdd(linear_base, addend, absl::StrCat("linear_index", i),
                        /*HasNUW=*/true, /*HasNSW=*/true);
    }
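    // The loop below converts the flat index into a multidimensional index by
    // repeated div/mod over the logical dimensions, minor-most last. For
    // example (purely illustrative numbers), with shape dimensions [2, 3, 4]
    // and a linear index of 17:
    //   dim 2: 17 % 4 = 1, carry 17 / 4 = 4
    //   dim 1:  4 % 3 = 1, carry  4 / 3 = 1
    //   dim 0:  1 % 2 = 1
    // giving the multidimensional index (1, 1, 1).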
    auto dims = shape_.dimensions();
    llvm::Value* last_val = linear_base;
    for (int j = dims.size() - 1; j >= 0; j--) {
      multidim[j] = b_->CreateURem(
          last_val, llvm::ConstantInt::get(index_type, dims[j]));
      last_val = b_->CreateUDiv(
          last_val, llvm::ConstantInt::get(index_type, dims[j]));
    }
    array_indices.emplace_back(multidim, shape_, index_type);
  }

  // We don't need to do bounds checking because this method is only
  // triggered for cases where we have already verified the bounds.
  llvm::BasicBlock* current_block = b_->GetInsertBlock();
  llvm::BasicBlock* body_block =
      llvm_ir::CreateBasicBlock(nullptr, "fusion-body", b_);
  if (current_block->getTerminator() == nullptr) {
    exit_bb_ = llvm_ir::CreateBasicBlock(nullptr, "after-fusion-body", b_);
  } else {
    exit_bb_ = current_block->splitBasicBlock(b_->GetInsertPoint(),
                                              "after-fusion-body");
    current_block->getTerminator()->eraseFromParent();
  }
  b_->SetInsertPoint(current_block);
  b_->CreateBr(body_block);
  b_->SetInsertPoint(body_block);
  b_->CreateBr(exit_bb_);

  // Set the IR builder insertion point to the beginning of the body block.
  llvm_ir::SetToFirstInsertPoint(body_block, b_);
  return array_indices;
}

std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
                                                   llvm::Type* index_type,
                                                   llvm::Value* base_index) {
  // Emit the following code in LLVM IR:
  //   linear_index = blockIdx.x * blockDim.x * blockDim.y
  //                  [+ threadIdx.y * blockDim.x] + threadIdx.x;
  //   if (linear_index < num_elements) {
  //     array_index = LinearIndexToMultidimensionalIndex(shape_, linear_index);
  //     ...
  //   }
  // The part between [] is added only if blockDim.y > 1.
  // blockIdx.y and gridDim.y are always 1.
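  //
  // For example (illustrative launch values, not taken from this file): with
  // blockDim.x = 128, blockDim.y = 1, unroll_factor = 4, blockIdx.x = 2 and
  // threadIdx.x = 7, the physical-order base computed below is
  //   linear_index = (2 * 128 + 7) * 4 = 1052,
  // and the unrolled body then covers elements 1052..1055.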

  // Per the PTX documentation:
  //   "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x"
  //
  // %nctaid.x is currently specified as 2147483647.
  if (launch_dimensions_.thread_counts_per_block().y > 1) {
    // When blockDim.y > 1, we are in the small-row case: blockDim.x threads
    // map exactly to one row, and blockDim.y maps consecutive rows to the
    // same block. This avoids block sizes that are too small to be
    // efficient.
    CHECK(launch_config_.row_vectorized);
    CHECK_EQ(shape_.dimensions().back(),
             launch_dimensions_.thread_counts_per_block().x *
                 launch_config_.unroll_factor);
  }
  CHECK_EQ(launch_dimensions_.thread_counts_per_block().z, 1);
  CHECK_EQ(launch_dimensions_.block_counts().y, 1);
  CHECK_EQ(launch_dimensions_.block_counts().z, 1);
  VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor "
          << launch_config_.unroll_factor;
  CHECK_NE(index_type, nullptr);

  if (launch_config_.logical_order) {
    return EmitLogicalIndexAndSetExitBasicBlock(loop_name, index_type,
                                                base_index);
  }

  std::vector<llvm_ir::IrArray::Index> array_indices;
  LinearBaseAndThreadIdx linear_base_and_thread_idx =
      EmitLinearBaseAndThreadIdx(index_type, base_index);

  llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base;
  llvm::Value* thread_id_x = linear_base_and_thread_idx.thread_idx;

  // When row_vectorized is enabled, the innermost dimension matches the
  // block size, so we can generate simpler indexing for that dimension.
  // This helps LLVM emit vectorized code in those cases.
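  //
  // Sketch of the idea (illustrative values): with blockDim.x = 32 and
  // unroll_factor = 4, a row of 128 elements is covered by one block, and
  // thread t writes the contiguous elements [t * 4, t * 4 + 4) of its row.
  // The constant-stride, in-row index built below is what lets LLVM turn the
  // four scalar stores into a vector store.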
  llvm::Value* row_index = nullptr;
  if (!launch_config_.row_vectorized) {
    array_indices.emplace_back(linear_index_base, shape_, b_);
  } else {
    // Simpler index for row computation.
    // This will allow LLVM to vectorize.
    row_index = b_->CreateMul(
        thread_id_x,
        llvm::ConstantInt::get(index_type, launch_config_.unroll_factor),
        "row_index", /*HasNUW=*/true, /*HasNSW=*/true);
    std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
    multidim.back() = row_index;
    array_indices.emplace_back(linear_index_base, multidim, shape_, b_);
  }

  for (int i = 1; i < launch_config_.unroll_factor; ++i) {
    llvm::Value* linear_index =
        b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i),
                      absl::StrCat("linear_index", i),
                      /*HasNUW=*/true, /*HasNSW=*/true);
    if (!launch_config_.row_vectorized) {
      array_indices.emplace_back(linear_index, shape_, b_);
    } else {
      std::vector<llvm::Value*> multidim(shape_.rank(), nullptr);
      multidim.back() = b_->CreateAdd(
          row_index, llvm::ConstantInt::get(index_type, i),
          absl::StrCat("row_index_plus", i), /*HasNUW=*/true, /*HasNSW=*/true);
      array_indices.emplace_back(linear_index, multidim, shape_, b_);
    }
  }

  auto if_in_bounds = llvm_ir::EmitIfThenElse(
      b_->CreateICmpULT(
          linear_index_base,
          llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))),
      llvm_ir::IrName(loop_name, "in_bounds"), b_, false);

  // Set exit_bb_ to the exit block of the if structure.
  exit_bb_ = if_in_bounds.after_block;
  CHECK_NE(nullptr, exit_bb_);

  // Set IR builder insertion point to the body of the if structure.
  llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_);

  return array_indices;
}

Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name,
                                           llvm::Type* index_type,
                                           llvm::Value* base_indvar) {
  for (const llvm_ir::IrArray::Index& array_index :
       EmitIndexAndSetExitBasicBlock(loop_name, index_type, base_indvar)) {
    TF_RETURN_IF_ERROR(body_emitter_(array_index));
  }
  return OkStatus();
}

Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name,
                                     llvm::Type* index_type) {
  if (index_type == nullptr) {
    index_type = b_->getInt64Ty();
  }
  int64_t total_threads = launch_dimensions_.launch_bound();
  int64_t num_elements = ShapeUtil::ElementsIn(shape_);
  // If all the elements are handled by the threads of a single launch, there
  // is no need to add a loop inside the kernel.
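  //
  // For example (illustrative numbers): with launch_bound() = 2048 threads,
  // unroll_factor = 4 and 10000 elements, one launch covers 8192 elements, so
  // the ksl.ForWithStatus loop below runs the body with base_indvar = 0 and
  // base_indvar = 8192, each time striding the whole grid over the shape.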
  if (total_threads * launch_config_.unroll_factor >= num_elements) {
    VLOG(1) << "No loops inside the kernel";
    TF_RETURN_IF_ERROR(EmitSerialLoop(loop_name, index_type));
  } else {
    KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll);
    auto constant = [&](int64_t val) {
      return llvm::ConstantInt::get(index_type, val);
    };

    TF_RETURN_IF_ERROR(ksl.ForWithStatus(
        "loop", constant(0), constant(num_elements),
        constant(total_threads * launch_config_.unroll_factor),
        [&](llvm::Value* base_indvar) {
          return EmitSerialLoop(loop_name, index_type, base_indvar);
        }));
  }

  // Set the insertion point of b_ to the loop exit, so that
  // code emitted for later instructions will be correctly placed.
  if (exit_bb_ != nullptr) {
    b_->SetInsertPoint(exit_bb_);
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla