1 /*Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <numeric>
25 #include <optional>
26 #include <string>
27 #include <type_traits>
28 #include <utility>
29 #include <vector>
30
31 #include "absl/algorithm/container.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/container/inlined_vector.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_format.h"
36 #include "absl/types/span.h"
37 #include "llvm/ADT/APInt.h"
38 #include "llvm/ADT/StringRef.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/Function.h"
41 #include "llvm/IR/IRBuilder.h"
42 #include "llvm/IR/Instructions.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/Linker/Linker.h"
46 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
47 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
48 #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
49 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
50 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
51 #include "mlir/IR/Attributes.h" // from @llvm-project
52 #include "mlir/IR/Builders.h" // from @llvm-project
53 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
54 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
55 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
56 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project
57 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
58 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
59 #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
60 #include "tensorflow/compiler/mlir/utils/name_utils.h"
61 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
62 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
63 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
64 #include "tensorflow/compiler/xla/layout_util.h"
65 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
66 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
67 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
68 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
69 #include "tensorflow/compiler/xla/permutation_util.h"
70 #include "tensorflow/compiler/xla/primitive_util.h"
71 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
72 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
73 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
74 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
77 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
78 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
79 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
81 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
82 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
83 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
84 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
85 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
86 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
88 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
89 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
90 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
91 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
92 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
93 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
94 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
95 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
97 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
98 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
99 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
100 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
101 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
102 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
103 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
104 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
105 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
106 #include "tensorflow/compiler/xla/service/hlo_computation.h"
107 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
108 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
109 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
110 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
111 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
112 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
113 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
114 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
115 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
116 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
117 #include "tensorflow/compiler/xla/service/name_uniquer.h"
118 #include "tensorflow/compiler/xla/shape.h"
119 #include "tensorflow/compiler/xla/shape_util.h"
120 #include "tensorflow/compiler/xla/status_macros.h"
121 #include "tensorflow/compiler/xla/union_find.h"
122 #include "tensorflow/compiler/xla/util.h"
123 #include "tensorflow/compiler/xla/xla_data.pb.h"
124 #include "tensorflow/core/platform/errors.h"
125 #include "tensorflow/core/platform/human_readable_json.h"
126 #include "tensorflow/core/platform/logging.h"
127
128 #if GOOGLE_CUDA
129 #include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h"
130 #endif // GOOGLE_CUDA
131
132 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
133 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
134 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
135 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
136
137 namespace xla {
138 namespace gpu {
139
140 namespace {
141
142 using absl::InlinedVector;
143 using absl::StrCat;
144 using llvm_ir::IrArray;
145 using llvm_ir::IrName;
146 using std::optional;
147
148 const auto kDimX = TilingScheme::DimX;
149 const auto kDimY = TilingScheme::DimY;
150 const auto kDimZ = TilingScheme::DimZ;
151 const auto kDimTot = TilingScheme::DimTot;
152
153 const auto kLinearIndexingX = TilingScheme::LinearIndexingX;
154 const auto kStridedIndexingX = TilingScheme::StridedIndexingX;
155
AnnotateWithInt32Value(std::string name,int64_t value,const std::string & kernel_name,llvm::Module * llvm_module)156 void AnnotateWithInt32Value(std::string name, int64_t value,
157 const std::string& kernel_name,
158 llvm::Module* llvm_module) {
159 llvm::NamedMDNode* nvvm_annotations_node =
160 llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
161 llvm::Function* ir_kernel = llvm_module->getFunction(kernel_name.c_str());
162 llvm::LLVMContext& llvm_context = llvm_module->getContext();
163
164 nvvm_annotations_node->addOperand(llvm::MDNode::get(
165 llvm_context,
166 {llvm::ConstantAsMetadata::get(ir_kernel),
167 llvm::MDString::get(llvm_context, name),
168 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
169 llvm::IntegerType::get(llvm_context, /*NumBits=*/32), value))}));
170 }
171
172 // Annotates the launch dimensions of the corresponding IR kernel in
173 // `llvm_module`.
AnnotateThunkLaunchDimensions(const LaunchDimensions & launch_dims,const std::string & kernel_name,llvm::Module * llvm_module)174 void AnnotateThunkLaunchDimensions(const LaunchDimensions& launch_dims,
175 const std::string& kernel_name,
176 llvm::Module* llvm_module) {
177 // Add __launch_bounds__ to metadata. This limits registers per thread to
178 // avoid out-of-resources launching errors.
179
180 // Our launch bounds are exact, so we can specify them as
181 // reqntid[xyz] rather than maxntid[xyz].
182 AnnotateWithInt32Value("reqntidx", launch_dims.thread_counts_per_block().x,
183 kernel_name, llvm_module);
184 if (launch_dims.thread_counts_per_block().y > 1) {
185 AnnotateWithInt32Value("reqntidy", launch_dims.thread_counts_per_block().y,
186 kernel_name, llvm_module);
187 }
188 if (launch_dims.thread_counts_per_block().z > 1) {
189 AnnotateWithInt32Value("reqntidz", launch_dims.thread_counts_per_block().z,
190 kernel_name, llvm_module);
191 }
192 }
193
BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,int64_t v)194 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
195 int64_t v) {
196 mlir::APInt value(sizeof(int64_t) * 8, v, /*isSigned=*/true);
197 return std::binary_search(
198 elements.begin(), elements.end(), value,
199 [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
200 }
201
MhloOpIsElementwise(mlir::Operation * op)202 bool MhloOpIsElementwise(mlir::Operation* op) {
203 CHECK(op->getDialect() ==
204 op->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>());
205 auto opcode = *MhloToHloOpcode(op);
206 if (HloInstruction::IsOpElementwise(opcode)) {
207 return true;
208 }
209 if (opcode == HloOpcode::kMap) {
210 int iota = 0;
211 for (const llvm::APInt& i :
212 mlir::cast<mlir::mhlo::MapOp>(op).dimensions()) {
213 if (i.getZExtValue() != iota) {
214 return false;
215 }
216 iota++;
217 }
218 return true;
219 }
220 // TODO(timshen): not sure about whether porting
221 // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
222 // doesn't use such information.
223 return false;
224 }
225
IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion)226 bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) {
227 int instruction_count = 0;
228 for (mlir::Operation& instr : fusion.getRegion().front()) {
229 if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
230 mlir::bufferization::ToTensorOp, mlir::memref::TensorStoreOp>(
231 &instr)) {
232 continue;
233 }
234 instruction_count++;
235 }
236 return instruction_count == 1;
237 }
238
MayPreventVectorization(mlir::Operation * op)239 bool MayPreventVectorization(mlir::Operation* op) {
240 // An empirically chosen constant: unrolling concat with a large amount of
241 // arguments causes excessive register spilling.
242 static constexpr int kMaxConcatArgumentsForUnrolling = 10;
243
244 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
245
246 for (mlir::Operation& instr : fusion.getRegion().front()) {
247 if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
248 mlir::bufferization::ToTensorOp, mlir::memref::TensorStoreOp>(
249 &instr)) {
250 continue;
251 }
252
253 CHECK(instr.getDialect() ==
254 instr.getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>())
255 << MlirToString(op);
256 switch (*MhloToHloOpcode(&instr)) {
257 case HloOpcode::kReduceWindow:
258 case HloOpcode::kSort:
259 case HloOpcode::kDot:
260 case HloOpcode::kSin:
261 case HloOpcode::kCos:
262 case HloOpcode::kPower:
263 case HloOpcode::kAtan2:
264 return true;
265 case HloOpcode::kConcatenate:
266 if (instr.getOperands().size() > kMaxConcatArgumentsForUnrolling) {
267 return true;
268 }
269 break;
270 case HloOpcode::kReduce:
271 if (instr.getNumResults() > 1) {
272 return true;
273 }
274 break;
275 default:
276 break;
277 }
278 }
279 return false;
280 }
281
282 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(mlir::Type type,const HloModuleConfig & hlo_module_config)283 int ComputeMaxUnrollFactor(mlir::Type type,
284 const HloModuleConfig& hlo_module_config) {
285 int max_unroll_factor =
286 hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
287
288 // Find the largest possible power of two to unroll by.
289 // TODO(kramerb): Make this smarter.
290
291 auto shaped_type = type.cast<mlir::ShapedType>();
292 int64_t num_elements = std::accumulate(
293 shaped_type.getShape().begin(), shaped_type.getShape().end(), int64_t{1},
294 std::multiplies<int64_t>());
295 for (int i = max_unroll_factor; i > 1; i /= 2) {
296 if (num_elements % i == 0) {
297 return i;
298 }
299 }
300
301 // Cannot unroll.
302 return 1;
303 }
304
305 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(mlir::Operation * op,const HloModuleConfig & hlo_module_config)306 int ComputeMaxUnrollFactor(mlir::Operation* op,
307 const HloModuleConfig& hlo_module_config) {
308 mlir::Type element_shape = [&] {
309 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
310 return fusion.getFusionRoots()[0]->getResult(0).getType();
311 }
312 return GetHloOutputs(op)[0].getType();
313 }();
314 return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
315 }
316
317 // Returns the llvm type for the indices used in the kernel that contains the
318 // hlo instruction. Such indices include the index for the parallel loop and
319 // the indices for the tensors accessed by the kernel. The return type is i32
320 // iff the following conditions are met:
321 // . The launch_size of the kernel is within the range of i32.
322 // . The sizes of all the tensors accessed within the kernel are within the
323 // range of i32.
324 // Otherwise, the return type is i64.
GetIndexTypeForKernel(const HloInstruction * hlo,int64_t launch_size,llvm::IRBuilder<> * b)325 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo,
326 int64_t launch_size, llvm::IRBuilder<>* b) {
327 // Find the unnested hlo instruction for which the kernel is generated for.
328 const HloInstruction* unnested_hlo = hlo;
329 const HloComputation* computation = hlo->parent();
330 if (computation->IsFusionComputation()) {
331 unnested_hlo = computation->FusionInstruction();
332 }
333
334 auto shape_in_range = [&](const Shape& s) {
335 bool in_range = true;
336 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
337 const ShapeIndex& /*index*/) {
338 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
339 in_range = false;
340 }
341 });
342
343 return in_range;
344 };
345
346 llvm::Type* i64_ty = b->getInt64Ty();
347 // Check launch dimension
348 if (!IsInt32(launch_size)) {
349 return i64_ty;
350 }
351
352 // Check the size of result tensors
353 if (!shape_in_range(unnested_hlo->shape())) {
354 return i64_ty;
355 }
356
357 auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
358 return shape_in_range(operand->shape());
359 };
360
361 // Check the size of input tensors
362 if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
363 return i64_ty;
364 }
365
366 // Check the size of the internal result tensors
367 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
368 if (!absl::c_all_of(
369 unnested_hlo->fused_instructions_computation()->instructions(),
370 hlo_shape_in_range)) {
371 return i64_ty;
372 }
373 }
374
375 return b->getInt32Ty();
376 }
377
378 // The same as GetIndexTypeForKernel, but works with MLIR ops.
GetIndexTypeForKernel(mlir::Operation * op,int64_t launch_size,llvm::IRBuilder<> * b)379 llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size,
380 llvm::IRBuilder<>* b) {
381 auto shape_in_range = [&](const Shape& s) {
382 bool in_range = true;
383 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
384 const ShapeIndex& /*index*/) {
385 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
386 in_range = false;
387 }
388 });
389
390 return in_range;
391 };
392
393 llvm::Type* i64_ty = b->getInt64Ty();
394 // Check launch dimension
395 if (!IsInt32(launch_size)) {
396 return i64_ty;
397 }
398
399 // Check the size of result tensors
400 for (auto result : GetHloOutputs(op)) {
401 if (!shape_in_range(GetShape(result))) {
402 return i64_ty;
403 }
404 }
405
406 auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
407 return shape_in_range(GetShape(operand));
408 };
409
410 // Check the size of input tensors
411 if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
412 return i64_ty;
413 }
414
415 // Check the size of the internal result tensors
416 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
417 auto result = fusion.getRegion().walk([&](mlir::Operation* op) {
418 for (mlir::Value result : op->getResults()) {
419 if (!hlo_shape_in_range(result)) {
420 return mlir::WalkResult::interrupt();
421 }
422 }
423 return mlir::WalkResult::advance();
424 });
425 if (result.wasInterrupted()) {
426 return i64_ty;
427 }
428 }
429
430 return b->getInt32Ty();
431 }
432
433 // Gets the input shape of the ROOT slices, which will be used as the kernel
434 // launch dims. The slice input fusion requires the input shapes of the ROOT
435 // slices to be the same although the (slice) output shapes can be different.
436 //
437 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
438 // slices are the same and the slices are non-strided. Otherwise, returns
439 // FailedPrecondition.
GetConsistentInputShapeForRootSlices(const HloComputation * fused_computation)440 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
441 const HloComputation* fused_computation) {
442 const HloInstruction& root = *fused_computation->root_instruction();
443 if (root.opcode() == HloOpcode::kSlice) {
444 return root.operands()[0]->shape();
445 }
446
447 CHECK_EQ(root.opcode(), HloOpcode::kTuple);
448 const Shape& first_slice_operand_shape =
449 root.operands()[0]->operands()[0]->shape();
450 for (size_t i = 1; i < root.operands().size(); ++i) {
451 const HloInstruction* slice = root.operands()[i];
452 const Shape& operand_shape = slice->operands()[0]->shape();
453 if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
454 operand_shape)) {
455 return FailedPrecondition(
456 "Fused slices do not have the same input shape, fused computation = "
457 "%s.",
458 root.parent()->name());
459 }
460 }
461
462 return first_slice_operand_shape;
463 }
464
465 // Returns a sanitized (doesn't need quoting) identifier name from a location.
GetIrNameFromLoc(mlir::Location loc)466 std::string GetIrNameFromLoc(mlir::Location loc) {
467 return llvm_ir::SanitizeConstantName(mlir::GetNameFromLoc(loc));
468 }
469
470 // For a row reduction, returns the number of rows we can process in parallel
471 // per warp.
RowReductionGetRowsPerWarp(int reduced_dimension_size)472 int RowReductionGetRowsPerWarp(int reduced_dimension_size) {
473 if (WarpSize() % reduced_dimension_size != 0 ||
474 reduced_dimension_size >= WarpSize()) {
475 return 1;
476 }
477 return WarpSize() / reduced_dimension_size;
478 }
479
480 } // namespace
481
IrEmitterUnnested(const HloModuleConfig & hlo_module_config,IrEmitterContext * ir_emitter_context)482 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
483 IrEmitterContext* ir_emitter_context)
484 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
485
Create(const HloModuleConfig & hlo_module_config,IrEmitterContext * ir_emitter_context)486 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
487 const HloModuleConfig& hlo_module_config,
488 IrEmitterContext* ir_emitter_context) {
489 return std::unique_ptr<IrEmitterUnnested>(
490 new IrEmitterUnnested(hlo_module_config, ir_emitter_context));
491 }
492
BuildKernelPrototype(absl::string_view name,absl::Span<const BufferAllocation * const> args)493 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
494 absl::string_view name, absl::Span<const BufferAllocation* const> args) {
495 // Compute the kernel name. The opcode string may contain "-" which cannot be
496 // in a PTX function name, so sanitize the name before uniquifying it.
497 std::string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
498 llvm_ir::SanitizeFunctionName(std::string(name)));
499
500 // Create the kernel and add it to the module.
501 llvm::LLVMContext& context = module_->getContext();
502 llvm::FunctionType* kernel_type = llvm::FunctionType::get(
503 /*Result=*/llvm::Type::getVoidTy(context),
504 std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
505 /*isVarArg=*/false);
506 llvm::Function* kernel =
507 llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
508 kernel_name.c_str(), module_);
509
510 // Add dereferenceable and alignment information to each of the kernel's
511 // parameters.
512 auto arg_it = kernel->arg_begin();
513 for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
514 const BufferAllocation* alloc = args[arg_no];
515 llvm::Argument& fn_arg = *arg_it;
516 ++arg_it;
517
518 kernel->addDereferenceableParamAttr(arg_no, alloc->size());
519
520 const int64_t alignment = [&] {
521 if (alloc->is_entry_computation_parameter()) {
522 return kEntryParameterAlignBytes;
523 } else if (alloc->is_constant()) {
524 return kConstantBufferAlignBytes;
525 } else {
526 return kXlaAllocatedBufferAlignBytes;
527 }
528 }();
529
530 kernel->addParamAttr(
531 arg_no,
532 llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
533
534 if (alloc->IsPreallocatedTempBuffer()) {
535 fn_arg.setName("temp_buf");
536 } else {
537 fn_arg.setName(StrCat("alloc", alloc->index()));
538 }
539 }
540
541 AnnotateFunctionAsGpuKernel(module_, kernel, &b_);
542
543 // TODO(b/65380986): Investigate if adding fast math flags for generated
544 // kernels makes sense.
545
546 // Update the insert point to the entry basic block.
547 llvm::BasicBlock* entry_bb =
548 llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
549
550 // Emit a "return void" at entry_bb's end, and set the insert point before
551 // that return instruction.
552 b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
553
554 return kernel;
555 }
556
GetAllocationSlice(mlir::Value v,std::string * constant_name)557 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice(
558 mlir::Value v, std::string* constant_name) {
559 return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(),
560 constant_name);
561 }
562
EmitConstant(mlir::Operation * op)563 Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) {
564 auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(op);
565 auto module = get_global->getParentOfType<mlir::ModuleOp>();
566 auto global = mlir::cast<mlir::memref::GlobalOp>(
567 module.lookupSymbol(get_global.getName()));
568
569 auto literal = global.getInitialValue()->dyn_cast<mlir::DenseElementsAttr>();
570 TF_RET_CHECK(literal);
571
572 const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
573
574 TF_ASSIGN_OR_RETURN(int element_bytes,
575 GetElementTypeBytes(literal.getType().getElementType()));
576 llvm::ArrayType* global_type = llvm::ArrayType::get(
577 b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
578
579 GpuExecutable::ConstantInfo info;
580 llvm::Constant* initializer;
581 if (should_emit_initializer) {
582 std::vector<uint8_t> content;
583 TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
584 initializer =
585 llvm::ConstantDataArray::get<uint8_t>(module_->getContext(), content);
586 } else {
587 TF_RETURN_IF_ERROR(
588 CopyDenseElementsDataToXlaFormat(literal, &info.content));
589 initializer = llvm::ConstantAggregateZero::get(global_type);
590 }
591
592 // These globals will be looked up by name by GpuExecutable so we need to
593 // give them an external linkage. Not all of their uses are visible in
594 // the LLVM IR so we can't give then a linkage that merely preserves their
595 // names (like available_externally), we also need to ensure that they stick
596 // around even if they're "unused".
597 //
598 // We may have to be more clever here in the future if we notice that we're
599 // keeping around too many globals because of their linkage.
600 llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
601 global_type, /*isConstant=*/should_emit_initializer,
602 llvm::GlobalValue::ExternalLinkage,
603 /*Initializer=*/initializer, global.getSymName(),
604 /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
605 /*AddressSpace=*/0,
606 /*isExternallyInitialized=*/false);
607 global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
608 module_->getGlobalList().push_back(global_for_const);
609
610 info.symbol_name.assign(global.getSymName().begin(),
611 global.getSymName().end());
612
613 info.allocation_index =
614 global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
615 ir_emitter_context_->constants().push_back(std::move(info));
616 return OkStatus();
617 }
618
GetConditionalThunkConfig(mlir::lmhlo::CaseOp op,std::vector<ThunkSequence> branch_thunk_sequences)619 static ConditionalThunkConfig GetConditionalThunkConfig(
620 mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) {
621 ConditionalThunkConfig config;
622 config.branch_index_is_bool = op.getIndex()
623 .getType()
624 .cast<mlir::ShapedType>()
625 .getElementType()
626 .isInteger(
627 /*width=*/1);
628 config.branch_count = op.getBranches().size();
629 // Pass nullptr as the HloInstruction* to the branch_thunks
630 // constructors because these SequentialThunks are logically "part of"
631 // this ConditionalThunk, and shouldn't be profiled separately from it.
632 config.branch_thunks.reserve(branch_thunk_sequences.size());
633 for (auto& branch_thunk_sequence : branch_thunk_sequences) {
634 config.branch_thunks.emplace_back(new SequentialThunk(
635 Thunk::ThunkInfo(), std::move(branch_thunk_sequence)));
636 }
637 return config;
638 }
639
EmitConditional(mlir::Operation * op)640 Status IrEmitterUnnested::EmitConditional(mlir::Operation* op) {
641 auto conditional = mlir::cast<mlir::lmhlo::CaseOp>(op);
642
643 std::vector<ThunkSequence> branch_thunks;
644
645 int branch_count = conditional.getBranches().size();
646 branch_thunks.reserve(branch_count);
647
648 for (int j = 0; j < branch_count; ++j) {
649 mlir::Region* branch_computation = &conditional.getBranches()[j];
650 TF_ASSIGN_OR_RETURN(
651 auto ir_emitter,
652 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
653 TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(branch_computation));
654 branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
655 }
656
657 ConditionalThunkConfig config =
658 GetConditionalThunkConfig(conditional, std::move(branch_thunks));
659
660 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.getIndex()));
661 AddThunkToThunkSequence(std::unique_ptr<Thunk>(
662 new ConditionalThunk(GetThunkInfo(op), std::move(config), slice)));
663 return OkStatus();
664 }
665
CreateLoad(llvm::Value * address,llvm::Type * data_type,int alignment_bytes)666 llvm::Value* IrEmitterUnnested::CreateLoad(llvm::Value* address,
667 llvm::Type* data_type,
668 int alignment_bytes) {
669 int data_bytes = data_type->getPrimitiveSizeInBits() /
670 primitive_util::BitWidth(PrimitiveType::U8);
671 if (alignment_bytes == 0) {
672 return b_.CreateLoad(data_type,
673 b_.CreateBitCast(address, data_type->getPointerTo()));
674 }
675
676 int alignment_bitwidth =
677 alignment_bytes * primitive_util::BitWidth(PrimitiveType::U8);
678
679 llvm::Value* output = llvm::ConstantInt::get(data_type, 0);
680 for (int offset_bytes = 0; offset_bytes < data_bytes;
681 offset_bytes += alignment_bytes) {
682 llvm::Value* offset_address = b_.CreateConstInBoundsGEP1_32(
683 b_.getInt8Ty(), address, offset_bytes, "offset_address");
684 llvm::Value* partial_value = b_.CreateLoad(b_.getIntNTy(alignment_bitwidth),
685 offset_address, "partial_value");
686 llvm::Value* zextd =
687 b_.CreateZExt(partial_value, output->getType(), "partial_value_zextd");
688 llvm::Value* shifted = b_.CreateShl(
689 zextd, llvm::ConstantInt::get(b_.getInt32Ty(), offset_bytes),
690 "partial_input_shifted");
691 output = b_.CreateAdd(output, shifted, "output_updated");
692 }
693 return output;
694 }
695
CreateStore(llvm::Value * data,llvm::Value * address,int alignment_bytes)696 void IrEmitterUnnested::CreateStore(llvm::Value* data, llvm::Value* address,
697 int alignment_bytes) {
698 int data_bytes = data->getType()->getPrimitiveSizeInBits() /
699 primitive_util::BitWidth(PrimitiveType::U8);
700 CHECK_GE(data_bytes, alignment_bytes);
701 if (alignment_bytes == 0) {
702 b_.CreateStore(data,
703 b_.CreateBitCast(address, data->getType()->getPointerTo()));
704 return;
705 }
706
707 int alignment_bitwidth =
708 alignment_bytes * primitive_util::BitWidth(PrimitiveType::U8);
709
710 for (int offset_bytes = 0; offset_bytes < data_bytes;
711 offset_bytes += alignment_bytes) {
712 llvm::Value* offset_address = b_.CreateConstInBoundsGEP1_32(
713 b_.getInt8Ty(), address, offset_bytes, "offset_address");
714 llvm::Value* shifted_partial = b_.CreateTrunc(
715 b_.CreateLShr(data,
716 llvm::ConstantInt::get(b_.getInt32Ty(), offset_bytes)),
717 b_.getIntNTy(alignment_bitwidth), "truncated_value");
718 b_.CreateStore(
719 shifted_partial,
720 b_.CreateBitCast(offset_address,
721 b_.getIntNTy(alignment_bitwidth)->getPointerTo()));
722 }
723 }
724
725 // Input = {dynamic array(with dynamic dimension meta data at the end)}
726 // Output = {static array, dynamic_dim0, dynamic_dim1}
EmitPadToStatic(mlir::Operation * op)727 Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) {
728 // TODO(jurahul): Create an op to represent PadToStatic.
729 auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
730 int unroll_factor = 1;
731 std::string ir_name = GetIrNameFromLoc(pad_to_static.getLoc());
732
733 const Shape& input_shape = GetShape(pad_to_static.getArgs().front());
734 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
735 CalculateLaunchDimensions(
736 input_shape, ir_emitter_context_->gpu_device_info(),
737 {unroll_factor}));
738 std::vector<llvm_ir::IrArray> ir_arrays;
739 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
740 BuildKernelThunk(pad_to_static, GetThunkInfo(op),
741 &ir_arrays, launch_dimensions));
742
743 const llvm_ir::IrArray source_array = ir_arrays[0];
744 const llvm_ir::IrArray output_array = ir_arrays[1];
745 auto output_dim_arrays =
746 absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
747
748 llvm::Type* index_ty = GetIndexTypeForKernel(
749 pad_to_static, launch_dimensions.launch_bound(), &b_);
750
751 // pseudo code for PadToStatic on a 2d array
752 // int* source_array = input[0];
753 // int* dest_array = output[0];
754 llvm::Value* source_buffer = source_array.GetBasePointer();
755 llvm::Value* raw_buffer =
756 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
757
758 // TODO(jurahul): input_shape here is the static shape of the input (which has
759 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
760 // memref. When we change that to a more appropriate representation in MLIR,
761 // fix this code to correctly deduce the static shape backing the dynamically
762 // shaped memref.
763 int64_t raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
764
765 // int* dyn_dim0_size = source_array + meta_data_offset;
766 // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
767 std::vector<llvm::Value*> dynamic_dims;
768 int alignment = raw_data_size % sizeof(int32_t);
769 for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) {
770 // Dynamic size of each dimension is attached at the end of the source
771 // array(operand(0)). We need to extract these value.
772 const Shape& dim_shape = GetShape(pad_to_static.getOutput()[i]);
773 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
774
775 const int64_t dim_index = i - 1;
776 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
777 b_.getInt8Ty(), raw_buffer,
778 raw_data_size + dim_index * sizeof(int32_t));
779 llvm::Value* dyn_dim_size =
780 CreateLoad(metadata, b_.getInt32Ty(), alignment);
781 dynamic_dims.push_back(dyn_dim_size);
782 }
783
784 // only one thread need to store the dynamic index
785 // int thread_id = GetThreadId();
786 // int block_id = GetBlockId();
787 // if (thread_id == 0 && block_id == 0) {
788 // *output[1] = *dyn_dim0_size;
789 // *output[2] = *dyn_dim1_size;
790 // }
791 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
792 for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) {
793 const int64_t dim_index = i - 1;
794 llvm::Value* dest_dim_size_address =
795 output_dim_arrays[dim_index].GetBasePointer();
796 // output[i] stores dynamic_dim_(i-1)
797 CreateStore(dynamic_dims[dim_index], dest_dim_size_address, alignment);
798 }
799 });
800
801 // int dyn_element_total = 1;
802 // dyn_element_total *= *dyn_dim0_size;
803 // dyn_element_total *= *dyn_dim1_size;
804 llvm::Value* dyn_element_total = llvm::ConstantInt::get(index_ty, 1);
805 for (llvm::Value* dynamic_dim : dynamic_dims) {
806 dyn_element_total =
807 b_.CreateMul(dyn_element_total,
808 b_.CreateIntCast(dynamic_dim, dyn_element_total->getType(),
809 /*isSigned=*/true),
810 /*Name=*/"dyn_element_total_pad");
811 }
812
813 // linear_index = block_id * threads_per_block + thread_id;
814 // if (linear_index < max_num_element) {
815 // Index static_index =
816 // delinerized(linerized_index, static_dim0_size, static_dim1_size);
817 // if (linerized_index < dyn_element_total) {
818 // Index dyn_index =
819 // delinerized(linerized_index, *dyn_dim0_size, *dyn_dim1_size);
820 // dest_array[dyn_index.dim0][dyn_index.dim1] =
821 // source_array[static_index.dim0][static_index.dim1];
822 // }
823 // }
824 llvm_ir::BodyEmitter body_generator =
825 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
826 llvm::Value* linearIndex =
827 array_index.Linearize(input_shape.dimensions(), &b_);
828 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
829 b_.CreateICmpULT(linearIndex, dyn_element_total),
830 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
831 // Set IR builder insertion point to the body of the if structure.
832 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
833 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
834 absl::MakeSpan(dynamic_dims), &b_);
835 output_array.EmitWriteArrayElement(
836 dyn_index,
837 source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
838 /*use_linear_index=*/false);
839 return OkStatus();
840 };
841
842 const Shape& data_shape = GetShape(pad_to_static.getOutput().front());
843 TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape,
844 launch_dimensions, &b_,
845 {unroll_factor})
846 .EmitLoop(ir_name, index_ty));
847 thunk_sequence_.emplace_back(std::move(kernel_thunk));
848 return OkStatus();
849 }
850
851 // Input = {dynamic array(with dynamic dimension meta data at the end)}
852 // Output = {static array, dynamic_dim0, dynamic_dim1}
EmitSliceToDynamic(mlir::Operation * op)853 Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) {
854 // TODO(jurahul): Create an op to represent SliceToDynamic.
855 auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
856 int unroll_factor = 1;
857 std::string ir_name = GetIrNameFromLoc(slice_to_dynamic.getLoc());
858
859 const Shape& input_shape = GetShape(slice_to_dynamic.getArgs().front());
860 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
861 CalculateLaunchDimensions(
862 input_shape, ir_emitter_context_->gpu_device_info(),
863 {unroll_factor}));
864 llvm::Type* index_ty = GetIndexTypeForKernel(
865 slice_to_dynamic, launch_dimensions.launch_bound(), &b_);
866 std::vector<llvm_ir::IrArray> ir_arrays;
867 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
868 BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op),
869 &ir_arrays, launch_dimensions));
870
871 TF_RET_CHECK(slice_to_dynamic.getOutput().size() == 1);
872 const Shape& data_shape = GetShape(slice_to_dynamic.getOutput().front());
873
874 // TODO(jurahul): data_shape here is the static shape of the output (which has
875 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
876 // memref. When we change that to a more appropriate representation in MLIR,
877 // fix this code to correctly deduce the static shape backing the dynamically
878 // shaped memref.
879
880 // calculate the location where metadata needs to be inserted
881 // int* dyn_dim0_size = dest_array + meta_data_offset;
882 // int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
883 int32_t raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
884
885 // pseudo code for sliceToDynamic on a 2d array
886 // int* source_array = input[0];
887 // int* dest_array = output[0];
888 const llvm_ir::IrArray data_array = ir_arrays.back();
889 llvm::Value* dest_buffer = data_array.GetBasePointer();
890 llvm::Value* raw_buffer =
891 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
892
893 // Load dynamic dimensions from memory.
894 std::vector<llvm::Value*> dynamic_dims;
895 int alignment = raw_data_size % sizeof(int32_t);
896 for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) {
897 // const int64_t dim_index = i - 1;
898 llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
899 llvm::Type* source_buffer_pointee_type = ir_arrays[i].GetBasePointeeType();
900 llvm::LoadInst* dyn_dim_size =
901 Load(source_buffer_pointee_type, source_buffer, "dyn_dim_size");
902 dynamic_dims.push_back(dyn_dim_size);
903 }
904
905 // only one thread need to store the dynamic index
906 // int thread_id = GetThreadId();
907 // int block_id = GetBlockId();
908 // if (thread_id == 0 && block_id == 0) {
909 // *dyn_dim0_size = *output[1];
910 // *dyn_dim1_size = *output[2];
911 // }
912 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
913 for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) {
914 const int64_t dim_index = i - 1;
915 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
916 b_.getInt8Ty(), raw_buffer,
917 raw_data_size + dim_index * sizeof(int32_t));
918 // output[i] stores dynamic_dim_(i-1)
919 CreateStore(dynamic_dims[dim_index], metadata, alignment);
920 }
921 });
922
923 // int dyn_element_total = 1;
924 // dyn_element_total *= dyn_dim0_size;
925 // dyn_element_total *= dyn_dim1_size;
926 llvm::Value* dyn_element_total = llvm::ConstantInt::get(index_ty, 1);
927 for (llvm::Value* dynamic_dim : dynamic_dims) {
928 dyn_element_total =
929 b_.CreateMul(dyn_element_total,
930 b_.CreateIntCast(dynamic_dim, dyn_element_total->getType(),
931 /*isSigned=*/true),
932 /*Name=*/"dyn_element_total_slice");
933 }
934
935 // linear_index = block_id * threads_per_block + thread_id;
936 // if (linear_index < max_num_element) {
937 // Index static_index =
938 // delinerized(linerized_index, static_dim0_size, static_dim1_size);
939 // if (linerized_index < dyn_element_total) {
940 // Index dyn_index =
941 // delinerized(linerized_index, *dyn_dim0_size, *dyn_dim1_size);
942 // dest_array[static_index.dim0][static_index.di] =
943 // source_array[dyn_index.dim0][dyn_index.dim1];
944 // }
945 // }
946 llvm_ir::BodyEmitter body_generator =
947 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
948 llvm::Value* linearIndex =
949 array_index.Linearize(input_shape.dimensions(), &b_);
950 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
951 b_.CreateICmpULT(linearIndex, dyn_element_total),
952 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
953 // Set IR builder insertion point to the body of the if structure.
954 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
955 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
956 absl::MakeSpan(dynamic_dims), &b_);
957
958 data_array.EmitWriteArrayElement(
959 array_index,
960 ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
961 /*use_linear_index=*/false),
962 &b_);
963 return OkStatus();
964 };
965
966 TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape,
967 launch_dimensions, &b_,
968 {unroll_factor})
969 .EmitLoop(ir_name, index_ty));
970 thunk_sequence_.emplace_back(std::move(kernel_thunk));
971 return OkStatus();
972 }
973
EmitConvolutionThunk(mlir::Operation * op)974 Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) {
975 using mlir::dyn_cast;
976 using mlir::lmhlo_gpu::Activation;
977 using mlir::lmhlo_gpu::ConvBackwardFilterOp;
978 using mlir::lmhlo_gpu::ConvBackwardInputOp;
979 using mlir::lmhlo_gpu::ConvForwardFusedOp;
980 using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
981 using mlir::lmhlo_gpu::ConvForwardOp;
982
983 // Last 2 operands of the convolution operation are the result and scratch.
984 std::vector<BufferAllocation::Slice> operand_slices;
985 int64_t num_operands = op->getNumOperands();
986 operand_slices.reserve(num_operands - 2);
987 for (mlir::Value operand : op->getOperands().drop_back(2)) {
988 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
989 operand_slices.push_back(slice);
990 }
991
992 mlir::Value conv_result = op->getOperand(num_operands - 2);
993 mlir::Value scratch_result = op->getOperand(num_operands - 1);
994 TF_ASSIGN_OR_RETURN(auto conv_result_slice, GetAllocationSlice(conv_result));
995 TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result));
996
997 auto apply_layout = [](const Shape& shape,
998 mlir::ArrayRef<int64_t> minor_to_major) {
999 return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
1000 shape.dimensions(), minor_to_major);
1001 };
1002
1003 GpuConvDescriptor descriptor;
1004
1005 auto fill_conv_descriptor = [&](auto op) {
1006 descriptor.operand0_shape =
1007 apply_layout(GetShape(op->getOperand(0)),
1008 op.getBackendConfig().getOperand_0Layout());
1009 descriptor.operand1_shape =
1010 apply_layout(GetShape(op->getOperand(1)),
1011 op.getBackendConfig().getOperand_1Layout());
1012 descriptor.result_shape = apply_layout(
1013 GetShape(conv_result), op.getBackendConfig().getResultLayout());
1014 descriptor.dnums = ConvertConvDimensionNumbers(op.getDimensionNumbers());
1015 descriptor.scratch_size = scratch_slice.size();
1016 mlir::DenseIntElementsAttr window_strides =
1017 op.getWindowStrides().getValue();
1018 mlir::DenseIntElementsAttr padding = op.getPadding().getValue();
1019 mlir::DenseIntElementsAttr lhs_dilation = op.getLhsDilation().getValue();
1020 mlir::DenseIntElementsAttr rhs_dilation = op.getRhsDilation().getValue();
1021 mlir::DenseElementsAttr window_reversal = op.getWindowReversal().getValue();
1022 for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
1023 WindowDimension* dim = descriptor.window.add_dimensions();
1024 // Window size for a convolution is the same as the kernel size.
1025 // Kernel size of the convolution is operand1_shape. We need to look at
1026 // the convolution dimension numbers kernel spatial dimensions to get
1027 // the window size.
1028 int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
1029 dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
1030 dim->set_stride(window_strides.getValues<int64_t>()[index]);
1031 dim->set_padding_low(padding.getValues<int64_t>()[index]);
1032 dim->set_padding_high(padding.getValues<int64_t>()[index]);
1033 dim->set_base_dilation(lhs_dilation.getValues<int64_t>()[index]);
1034 dim->set_window_dilation(rhs_dilation.getValues<int64_t>()[index]);
1035 dim->set_window_reversal(window_reversal.getValues<bool>()[index]);
1036 }
1037 descriptor.feature_group_count = op.getFeatureGroupCount();
1038 {
1039 auto* algorithm = descriptor.backend_config.mutable_algorithm();
1040 algorithm->set_algo_id(op.getBackendConfig().getAlgorithm());
1041 algorithm->set_math_type(op.getBackendConfig().getTensorOpsEnabled()
1042 ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
1043 : se::dnn::AlgorithmProto::DEFAULT_MATH);
1044 for (int i = 0; i < op.getBackendConfig().getKnobIds().size(); ++i) {
1045 // N.B. tuning_knobs is a map rather than a repeated field, so this
1046 // doesn't require reserving space up front.
1047 (*algorithm
1048 ->mutable_tuning_knobs())[op.getBackendConfig().getKnobIds()[i]] =
1049 op.getBackendConfig().getKnobValues()[i];
1050 }
1051 algorithm->set_is_cudnn_frontend(
1052 op.getBackendConfig().getIsCudnnFrontend());
1053 auto workspace_size = op.getBackendConfig().getWorkspaceSize();
1054 if (workspace_size >= 0) {
1055 algorithm->mutable_workspace_size()->set_value(workspace_size);
1056 }
1057 }
1058 descriptor.backend_config.set_conv_result_scale(
1059 op.getResultScale().convertToDouble());
1060 };
1061
1062 auto set_activation_mode = [&](auto op) -> Status {
1063 TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
1064 ConvertConvActivationMode(op.getActivationMode()));
1065 descriptor.backend_config.set_activation_mode(
1066 static_cast<int64_t>(activation_mode));
1067 return OkStatus();
1068 };
1069
1070 if (auto conv = dyn_cast<ConvForwardOp>(op)) {
1071 descriptor.kind = CudnnConvKind::kForward;
1072 fill_conv_descriptor(conv);
1073 } else if (auto conv = dyn_cast<ConvBackwardInputOp>(op)) {
1074 descriptor.kind = CudnnConvKind::kBackwardInput;
1075 fill_conv_descriptor(conv);
1076 } else if (auto conv = dyn_cast<ConvBackwardFilterOp>(op)) {
1077 descriptor.kind = CudnnConvKind::kBackwardFilter;
1078 fill_conv_descriptor(conv);
1079 } else if (auto conv = dyn_cast<ConvForwardFusedOp>(op)) {
1080 descriptor.kind = CudnnConvKind::kForwardActivation;
1081 fill_conv_descriptor(conv);
1082 TF_RETURN_IF_ERROR(set_activation_mode(conv));
1083 } else if (auto conv = dyn_cast<ConvForwardFusedSideInputOp>(op)) {
1084 descriptor.kind = CudnnConvKind::kForwardActivation;
1085 fill_conv_descriptor(conv);
1086 TF_RETURN_IF_ERROR(set_activation_mode(conv));
1087 descriptor.backend_config.set_side_input_scale(
1088 conv.getSideInputScale().convertToDouble());
1089 } else {
1090 return InternalError("Unexpected operation");
1091 }
1092 TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
1093 AddThunkToThunkSequence(std::make_unique<ConvolutionThunk>(
1094 GetThunkInfo(op), std::move(config), std::move(operand_slices),
1095 conv_result_slice, scratch_slice));
1096 return OkStatus();
1097 }
1098
EmitGemmThunk(mlir::Operation * op)1099 Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) {
1100 auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op);
1101 TF_RET_CHECK(gemm != nullptr);
1102
1103 TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA()));
1104 TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB()));
1105 TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC()));
1106
1107 TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
1108 auto thunk =
1109 std::make_unique<GemmThunk>(GetThunkInfo(op), std::move(config), a, b, c);
1110
1111 AddThunkToThunkSequence(std::move(thunk));
1112 return OkStatus();
1113 }
1114
1115 #if GOOGLE_CUDA
1116
EmitCublasLtMatmulThunk(mlir::Operation * op)1117 Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) {
1118 auto matmul = mlir::dyn_cast<mlir::lmhlo_gpu::CublasLtMatmulOp>(op);
1119 TF_RET_CHECK(matmul != nullptr);
1120
1121 TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(matmul.getA()));
1122 TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(matmul.getB()));
1123 TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(matmul.getC()));
1124 TF_ASSIGN_OR_RETURN(auto d, GetAllocationSlice(matmul.getD()));
1125
1126 BufferAllocation::Slice bias;
1127 if (matmul.getBias() != nullptr) {
1128 TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias()));
1129 }
1130
1131 TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan,
1132 cublas_lt::MatmulPlan::For(matmul));
1133 auto thunk = std::make_unique<CublasLtMatmulThunk>(
1134 GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d,
1135 bias);
1136
1137 AddThunkToThunkSequence(std::move(thunk));
1138 return OkStatus();
1139 }
1140
1141 #endif // GOOGLE_CUDA
1142
1143 namespace {
1144 // An MLIR value and its name as defined in the ODS spec.
1145 struct NamedValue {
1146 mlir::Value value;
1147 absl::string_view name;
1148 };
1149
1150 // Determine if we enable the row optimized codegen. When we have a
1151 // fusion with only point-wise operations, scalar broadcasting and row
1152 // broadcasting, we can trigger a kernel that vectorize the row loads.
1153 // This speed up the kernel, in particular on A100.
1154 // Returns a pair<bool, int>. The bool mean should we try to enable
1155 // row vectorization. The int is the number of inputs with the higher
1156 // rank.
RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion)1157 std::pair<bool, int> RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion) {
1158 const auto is_row_major = [](mlir::Value value) {
1159 // Only tested when the inputs are row-major. So only
1160 // enable that case. Maybe it would works if only the
1161 // inner dimensions is contiguous.
1162 return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
1163 };
1164 bool row_vectorized =
1165 fusion.getFusionResults().size() == 1 && // Not tested with MOF.
1166 absl::c_all_of(GetHloOperands(fusion), is_row_major) &&
1167 absl::c_all_of(GetHloOutputs(fusion), is_row_major);
1168
1169 // Check that the operations in the fusion are supported. Each
1170 // supported operation (or category) must be manually vetted as XLA
1171 // only unrolls and relies on LLVM to vectorize. But this is brittle.
1172 // Currently tested and supported operations:
1173 // Elementwise, scalar and row broadcasting.
1174 //
1175 // We also detect at the same time if there is a row broadcasting
1176 // operation.
1177 bool some_row_broadcasting = false;
1178 auto out_rank =
1179 fusion.getFusionResults()[0].getType().cast<mlir::ShapedType>().getRank();
1180 int num_big_inputs = 0;
1181 for (mlir::Operation& op : fusion.getRegion().front()) {
1182 if (auto load = mlir::dyn_cast<mlir::bufferization::ToTensorOp>(op)) {
1183 auto rank = load.getResult().getType().cast<mlir::ShapedType>().getRank();
1184 num_big_inputs += static_cast<int>(rank == out_rank);
1185 continue;
1186 } else if (mlir::isa<mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1187 mlir::mhlo::ReturnOp, mlir::mhlo::ConstantOp,
1188 mlir::lmhlo::ConstantOp>(op)) {
1189 continue;
1190 }
1191 HloOpcode opcode = *MhloToHloOpcode(&op);
1192 if (HloInstruction::IsOpElementwise(opcode)) {
1193 continue;
1194 }
1195
1196 if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1197 const auto& broadcast_dimensions_size =
1198 broadcast.broadcast_dimensions().size();
1199 if (broadcast_dimensions_size == 0) {
1200 continue;
1201 }
1202 llvm::SmallVector<int64_t> broadcast_dimensions;
1203 broadcast_dimensions.reserve(broadcast_dimensions_size);
1204 for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) {
1205 broadcast_dimensions.push_back(int_value.getSExtValue());
1206 }
1207
1208 auto rank = GetShape(broadcast.getResult()).rank();
1209 if (broadcast_dimensions.size() == 1 &&
1210 broadcast_dimensions.back() == (rank - 1)) {
1211 some_row_broadcasting = true;
1212 continue;
1213 }
1214 }
1215 VLOG(2) << "Row vectorization not enabled due to this op: "
1216 << MlirToString(&op);
1217 return std::make_pair(false, 0);
1218 }
1219 // Trigger only when there is a row broadcasting.
1220 return std::make_pair(row_vectorized && some_row_broadcasting,
1221 num_big_inputs);
1222 }
1223 } // namespace
1224
1225 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
EmitCholeskyThunk(mlir::Operation * op)1226 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
1227 auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
1228
1229 const Shape shape = GetShape(cholesky_op.getInput());
1230 int ndim = shape.dimensions_size();
1231 CHECK_GE(ndim, 2);
1232 int64_t n = shape.dimensions(ndim - 1);
1233
1234 const auto& dims = shape.dimensions();
1235 int64_t batch_size =
1236 std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
1237 [](int64_t a, int64_t b) { return a * b; });
1238
1239 TF_ASSIGN_OR_RETURN(auto operand_buffer,
1240 GetAllocationSlice(cholesky_op.getInput()));
1241 TF_ASSIGN_OR_RETURN(auto a_buffer,
1242 GetAllocationSlice(cholesky_op.getOutput()));
1243 TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1244 GetAllocationSlice(cholesky_op.getScratch()));
1245 TF_ASSIGN_OR_RETURN(auto info_buffer,
1246 GetAllocationSlice(cholesky_op.getInfo()));
1247
1248 ThunkSequence thunks;
1249
1250 if (operand_buffer != a_buffer) {
1251 thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
1252 GetThunkInfo(op),
1253 /*source_address=*/operand_buffer,
1254 /*destination_buffer=*/a_buffer,
1255 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1256 }
1257
1258 CholeskyOptions options;
1259 options.set_lower(cholesky_op.getIsLower());
1260 thunks.push_back(std::make_unique<CholeskyThunk>(
1261 GetThunkInfo(op), options,
1262 PtxOptsFromDebugOptions(hlo_module_config_.debug_options()), a_buffer,
1263 workspace_buffer, info_buffer, shape.element_type(), batch_size, n));
1264
1265 // Elide the sequential thunk if there's no copy.
1266 if (thunks.size() == 1) {
1267 AddThunkToThunkSequence(std::move(thunks[0]));
1268 } else {
1269 AddThunkToThunkSequence(
1270 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
1271 }
1272
1273 return OkStatus();
1274 }
1275 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1276
EmitCustomCallThunk(mlir::Operation * op)1277 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
1278 auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1279 const std::string call_target_name = custom_call.getCallTargetName().str();
1280
1281 void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1282 call_target_name, std::string(platform_name()));
1283 if (!call_target) {
1284 return Unimplemented(
1285 "No registered implementation for custom call to \"%s\"",
1286 call_target_name);
1287 }
1288
1289 std::vector<CustomCallThunk::OptionalSlice> operands;
1290 std::vector<CustomCallThunk::OptionalSlice> results;
1291 if (custom_call.getTargetArgMapping()) {
1292 auto values_to_slices_with_token_holes =
1293 [&](mlir::ValueRange operands,
1294 mlir::ArrayRef<int64_t> op_to_target_mapping, int64_t num_target)
1295 -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1296 std::vector<CustomCallThunk::OptionalSlice> slices(num_target);
1297 for (auto index_and_value_it :
1298 llvm::zip(op_to_target_mapping, operands)) {
1299 int64_t index = std::get<0>(index_and_value_it);
1300 mlir::Value value = std::get<1>(index_and_value_it);
1301 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1302 GetAllocationSlice(value));
1303 slices[index] = slice;
1304 }
1305 return slices;
1306 };
1307
1308 mlir::lmhlo::CustomCallTargetArgMappingAttr target_mapping =
1309 *custom_call.getTargetArgMapping();
1310 TF_ASSIGN_OR_RETURN(operands, values_to_slices_with_token_holes(
1311 custom_call.getArgs(),
1312 target_mapping.getArgsToTargetArgs(),
1313 target_mapping.getNumArgs()));
1314 TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes(
1315 custom_call.getOutput(),
1316 target_mapping.getResultsToTargetResults(),
1317 target_mapping.getNumResults()));
1318 } else {
1319 auto values_to_slices = [&](mlir::ValueRange values)
1320 -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1321 std::vector<CustomCallThunk::OptionalSlice> slices;
1322 for (mlir::Value value : values) {
1323 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1324 GetAllocationSlice(value));
1325 slices.push_back(slice);
1326 }
1327 return slices;
1328 };
1329
1330 TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.getArgs()));
1331 TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.getOutput()));
1332 }
1333
1334 CustomCallThunk::CustomCallTarget custom_call_target;
1335
1336 // For information about this calling convention, see
1337 // xla/g3doc/custom_call.md.
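  // For illustration, a target registered for API_VERSION_ORIGINAL is an
  // ordinary function matching `original_call_type` below; a hypothetical
  // example:
  //
  //   void MyCustomCall(CustomCallThunk::Stream stream, void** buffers,
  //                     const char* opaque, size_t opaque_len);
  //
  // The status-returning versions take an additional trailing
  // XlaCustomCallStatus* argument. The target must be registered under the
  // same name and platform queried by the Lookup() call above.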
1338 switch (custom_call.getApiVersion()) {
1339 case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
1340 using original_call_type =
1341 void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1342 const char* /*opaque*/, size_t /*opaque_len*/);
1343 custom_call_target = [call_target](CustomCallThunk::Stream stream,
1344 void** buffers, const char* opaque,
1345 size_t opaque_len,
1346 XlaCustomCallStatus*) {
1347 auto typed_call_target =
1348 reinterpret_cast<original_call_type>(call_target);
1349 typed_call_target(stream, buffers, opaque, opaque_len);
1350 };
1351 break;
1352 case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
1353 case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
1354 using status_returning_call_type =
1355 void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1356 const char* /*opaque*/, size_t /*opaque_len*/,
1357 XlaCustomCallStatus* /*status*/);
1358 custom_call_target =
1359 reinterpret_cast<status_returning_call_type>(call_target);
1360 break;
1361 default:
1362 return InternalError("Unknown custom-call API version enum value: %d",
1363 custom_call.getApiVersion());
1364 }
1365
1366 auto thunk = std::make_unique<CustomCallThunk>(
1367 GetThunkInfo(op), std::move(custom_call_target), std::move(operands),
1368 std::move(results), custom_call.getBackendConfig().str());
1369 AddThunkToThunkSequence(std::move(thunk));
1370 return OkStatus();
1371 }
1372
1373 Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
1374 auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
1375 const Shape operand_shape = GetShape(fft_op.getOperand());
1376 const Shape output_shape = GetShape(fft_op.getOutput());
1377 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1378 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1379
1380 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1381 GetAllocationSlice(fft_op.getOperand()));
1382 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1383 GetAllocationSlice(fft_op.getOutput()));
1384 TF_ASSIGN_OR_RETURN(
1385 xla::FftType fft_type,
1386 ConvertFftType(mlir::mhlo::stringifyFftType(fft_op.getFftType())));
1387 auto fft_length_values = fft_op.getFftLength().getValues<int64_t>();
1388 std::vector<int64_t> fft_length(fft_length_values.begin(),
1389 fft_length_values.end());
1390
1391 AddThunkToThunkSequence(
1392 std::make_unique<FftThunk>(GetThunkInfo(op), fft_type, fft_length,
1393 /*input_buffer=*/arg_slice,
1394 /*output_buffer=*/dest_slice,
1395 /*input_shape=*/operand_shape,
1396 /*output_shape=*/output_shape));
1397 return OkStatus();
1398 }
1399
1400 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1401 Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) {
1402 auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1403
1404 auto operands = op->getOperands();
1405 TF_RET_CHECK(operands.size() == 4);
1406
1407   // We expect Fortran layout for everything other than the temp buffer (the
1408   // last operand). Fortran layout is the XLA default layout with elements 0
1409   // and 1 of minor_to_major swapped: instead of the default layout {3,2,1,0}
1410   // we'd have the Fortran layout {2,3,1,0}.
1411 TF_RET_CHECK(absl::c_all_of(operands.drop_back(1), [&](mlir::Value v) {
1412 const Shape& shape = GetShape(v);
1413 const Layout& layout = shape.layout();
1414 int n = layout.minor_to_major_size();
1415 if (n < 2) {
1416 return false;
1417 }
1418 // Unfortunately the HLO -> LMHLO -> HLO conversion loses layout information
1419 // if the shape has any dimensions of size 1: In that case, the new HLO
1420 // (which we see here) will have an arbitrary value for the location of the
1421 // size-1 dimension. Just skip this assertion if the shape has any
1422 // degenerate dimensions.
1423 if (absl::c_any_of(shape.dimensions(),
1424 [](int64_t dim) { return dim == 1; })) {
1425 return true;
1426 }
1427 return layout.minor_to_major(0) == n - 2 &&
1428 layout.minor_to_major(1) == n - 1 &&
1429 std::is_sorted(layout.minor_to_major().begin() + 2,
1430 layout.minor_to_major().end(),
1431 std::greater<int64_t>());
1432 }));
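  // Worked example (hypothetical rank-4 operand): with Fortran layout
  // {2,3,1,0} we have n == 4, minor_to_major(0) == 2 == n - 2,
  // minor_to_major(1) == 3 == n - 1, and the remaining entries {1,0} are in
  // decreasing order, so the check above passes; the default layout {3,2,1,0}
  // would fail it.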
1433
1434 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1435 GetAllocationSlice(operands[0]));
1436 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1437 GetAllocationSlice(operands[1]));
1438 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
1439 GetAllocationSlice(operands[2]));
1440 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice temp_slice,
1441 GetAllocationSlice(operands[3]));
1442
1443 const Shape b_shape = GetShape(operands[1]);
1444 const PrimitiveType elem_ty = b_shape.element_type();
1445 TriangularSolveOptions backend_config;
1446 TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(
1447 custom_call.getBackendConfig().str(), &backend_config));
1448
1449 ThunkSequence thunks;
1450
1451 // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1452 // aren't the same buffer.
1453 if (b_slice != result_slice) {
1454 thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
1455 Thunk::ThunkInfo(),
1456 /*source_address=*/b_slice,
1457 /*destination_buffer=*/result_slice,
1458 /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1459 }
1460
1461 int64_t m = b_shape.dimensions(b_shape.rank() - 2);
1462 int64_t n = b_shape.dimensions(b_shape.rank() - 1);
1463 int64_t batch_size = std::accumulate(
1464 b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1},
1465 [](int64_t a, int64_t b) { return a * b; });
1466 int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_ty);
1467 int64_t a_batch_stride =
1468 backend_config.left_side() ? m * m * elem_size : n * n * elem_size;
1469 int64_t b_batch_stride = m * n * elem_size;
1470 thunks.push_back(std::make_unique<TriangularSolveThunk>(
1471 GetThunkInfo(op), backend_config,
1472 PtxOptsFromDebugOptions(hlo_module_config_.debug_options()),
1473 /*a_buffer=*/a_slice, /*b_buffer=*/result_slice, temp_slice, elem_ty,
1474 batch_size, m, n, a_batch_stride, b_batch_stride));
1475
1476 // Elide the sequential thunk if there's no copy.
1477 if (thunks.size() == 1) {
1478 AddThunkToThunkSequence(std::move(thunks[0]));
1479 } else {
1480 AddThunkToThunkSequence(
1481 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
1482 }
1483 return OkStatus();
1484 }
1485 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1486
1487 // Convert the following form of fusion region:
1488 // fusion() {
1489 // %0 = tensor_load %external_memref0
1490 // %1 = tensor_load %external_memref1
1491 // ...
1492 // tensor_store %ret, %external_memref2
1493 // }
1494 // to
1495 // fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1496 // ...
1497 // mhlo.return %ret
1498 // })
1499 //
1500 // So that it's suitable for MHLO -> XLA HLO conversion.
1501 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1502 // instead.
1503 static Status ProcessFusionForConversion(mlir::Region* region,
1504 std::vector<Shape>* operand_shapes,
1505 std::vector<Shape>* output_shapes) {
1506 std::vector<mlir::bufferization::ToTensorOp> loads;
1507 std::vector<mlir::memref::TensorStoreOp> stores;
1508
1509 region->walk([&](mlir::bufferization::ToTensorOp load) {
1510 if (load.getMemref().getParentRegion() != region) {
1511 loads.push_back(load);
1512 }
1513 });
1514
1515 region->walk([&](mlir::memref::TensorStoreOp store) {
1516 if (store.getMemref().getParentRegion() != region) {
1517 stores.push_back(store);
1518 }
1519 });
1520
1521 for (auto& load : loads) {
1522 auto arg = region->addArgument(load.getType(), region->getLoc());
1523 load.replaceAllUsesWith(arg);
1524 Shape shape = GetShape(load.getResult());
1525 operand_shapes->push_back(std::move(shape));
1526 load.erase();
1527 }
1528
1529 std::vector<mlir::Value> returned_values;
1530 for (auto store : stores) {
1531 Shape shape = GetShape(store.getMemref());
1532 output_shapes->push_back(shape);
1533
1534 returned_values.push_back(store.getTensor());
1535 store.erase();
1536 }
1537
1538 region->back().back().erase();
1539   auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1540 auto loc = returned_values[0].getLoc();
1541 b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1542 return OkStatus();
1543 }
1544
1545 // We can iterate the output buffer in logical order instead of physical order
1546 // when it is safe and profitable to do so.
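// An illustrative case (hypothetical): a fusion with several row-major ({1,0})
// inputs and a single column-major ({0,1}) output. Generating the output index
// in logical order keeps the more numerous input reads in their fastest-varying
// dimension, at the cost of strided output writes, which is the trade-off the
// profitability checks below aim for.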
1547 static bool EnableLogicalIndexGenerationForOutput(
1548 LaunchDimensionsConfig launch_config, LaunchDimensions launch_dimensions,
1549 absl::Span<llvm_ir::IrArray> operand_arrays,
1550 absl::Span<llvm_ir::IrArray> output_arrays) {
1551 // Safety checks. Currently the logical index generation code has
1552 // limitations. Violating these conditions can give wrong output.
1553 if (launch_config.row_vectorized || launch_config.few_waves) return false;
1554 if (output_arrays.size() != 1) return false;
1555 const Shape& output_shape = output_arrays[0].GetShape();
1556 if (output_shape.is_dynamic()) return false;
1557   // Currently we require that the number of threads * unroll factor exactly
1558   // equals the number of output elements. This simplifies bounds checking
1559   // within the LLVM code generated for the fused kernel.
1560 if (ShapeUtil::ElementsIn(output_shape) !=
1561 (launch_dimensions.launch_bound() * launch_config.unroll_factor)) {
1562 return false;
1563 }
1564 if (launch_dimensions.thread_counts_per_block().y > 1 ||
1565 launch_dimensions.thread_counts_per_block().z > 1) {
1566 return false;
1567 }
1568 // Safety checks finish.
1569
1570 // Profitability checks.
1571 // TODO(b/228209668) Investigate alternative profitability heuristics.
1572   // The single output should not be in row-major layout.
1573 if (LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) {
1574 return false;
1575 }
1576 // We should have multiple inputs and all of them should have row-major
1577 // layout.
1578 if (operand_arrays.size() <= 1) return false;
1579 for (const llvm_ir::IrArray& input : operand_arrays) {
1580 const Shape& input_shape = input.GetShape();
1581 if (input_shape.is_dynamic()) return false;
1582 if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout())) {
1583 return false;
1584 }
1585 }
1586 return true;
1587 }
1588
1589 Status IrEmitterUnnested::EmitLaunchFunc(mlir::Operation* op) {
1590 auto launch_func = mlir::cast<mlir::gpu::LaunchFuncOp>(op);
1591 auto kernel_func =
1592 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1593 launch_func, launch_func.kernel());
1594 if (!kernel_func) {
1595 return InternalError("kernel '%s' not found",
1596 launch_func.getKernelName().str());
1597 }
1598
1599 // Lower kernel module to NVVM.
1600 auto gpu_module = kernel_func->getParentOfType<mlir::gpu::GPUModuleOp>();
1601 std::unique_ptr<llvm::Module> llvm_module = mlir::translateModuleToLLVMIR(
1602 gpu_module, module_->getContext(), gpu_module.getName());
1603 if (!llvm_module)
1604 return InternalError("Failed to translate GPU module to LLVM");
1605
1606 // Add kernel to LLVM module.
1607 llvm_module->setDataLayout(module_->getDataLayout());
1608 llvm::Linker::linkModules(*module_, std::move(llvm_module));
1609
1610 // Retrieve launch dimensions from arith.constant ops.
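  // Schematically, the launch op is expected to look like this (hypothetical
  // MLIR):
  //   %c42 = arith.constant 42 : index
  //   %c64 = arith.constant 64 : index
  //   %c1  = arith.constant 1 : index
  //   gpu.launch_func @gpu_module::@kernel
  //       blocks in (%c42, %c1, %c1) threads in (%c64, %c1, %c1) args(...)
  // Any dimension that is not defined by an arith.constant maps to -1 below.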
1611 auto get_dim3d = [](mlir::gpu::KernelDim3 dim3) {
1612 auto get_const = [](mlir::Value value) -> int64_t {
1613 auto const_op = value.getDefiningOp<mlir::arith::ConstantOp>();
1614 if (!const_op) return -1;
1615 auto attr = const_op.getValue().cast<mlir::IntegerAttr>();
1616 if (!attr) return -1;
1617 return attr.getValue().getSExtValue();
1618 };
1619 return LaunchDimensions::Dim3D{get_const(dim3.x), get_const(dim3.y),
1620 get_const(dim3.z)};
1621 };
1622 LaunchDimensions launch_dimensions(
1623 get_dim3d(launch_func.getGridSizeOperandValues()),
1624 get_dim3d(launch_func.getBlockSizeOperandValues()));
1625
1626 // Create BufferSlice array from launch_func arguments, using the
1627 // attribute depicting which arguments are written by the kernel.
1628 std::vector<BufferSlice> slices;
1629 unsigned num_kernel_operands = launch_func.getNumKernelOperands();
1630 slices.reserve(num_kernel_operands);
1631 mlir::ArrayRef<mlir::Attribute> written_operands =
1632 mlir::getWrittenOperandsAttribute(launch_func).getValue();
1633 for (const auto& [operand, written] :
1634 llvm::zip_first(launch_func.operands(),
1635 written_operands.take_back(num_kernel_operands))) {
1636 BufferSlice slice;
1637 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
1638 GetAllocationSlice(operand, &slice.constant_name));
1639 slice.shape = GetShape(operand);
1640 slice.written = written.cast<mlir::BoolAttr>().getValue();
1641 slices.push_back(std::move(slice));
1642 }
1643
1644 // Add kernel prototype to module_, kernel thunk to thunk_sequence_.
1645 std::string kernel_name = GetIrNameFromLoc(launch_func.getLoc());
1646 std::vector<llvm_ir::IrArray> ir_arrays;
1647 TF_ASSIGN_OR_RETURN(
1648 std::unique_ptr<Thunk> kernel_thunk,
1649 BuildKernelThunkImpl(kernel_name, GetThunkInfo(op), slices, &ir_arrays,
1650 launch_dimensions));
1651 thunk_sequence_.emplace_back(std::move(kernel_thunk));
1652
1653 // Move function body into kernel prototype.
1654 llvm::Function* prototype_func = b_.GetInsertBlock()->getParent();
1655 llvm::Function* implementation_func =
1656 module_->getFunction(kernel_func.getName());
1657 prototype_func->getBasicBlockList().splice(
1658 prototype_func->end(), implementation_func->getBasicBlockList());
1659 for (const auto& [arg, ir_array] :
1660 llvm::zip_first(implementation_func->args(), ir_arrays)) {
1661 arg.replaceAllUsesWith(ir_array.GetBasePointer());
1662 }
1663 implementation_func->eraseFromParent();
1664
1665 // Replace pre-existing return with unconditional branch to next block.
1666 llvm::Instruction* terminator =
1667 prototype_func->getEntryBlock().getTerminator();
1668 llvm::BranchInst::Create(&*std::next(prototype_func->begin()), terminator);
1669 terminator->eraseFromParent();
1670
1671   return OkStatus();
1672 }
1673
1674 // TODO(timshen): update the comment once the HandleFusion code path is deleted.
1675 //
1676 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1677 // subclass. The logic is de-virtualized and less scattered.
1678 Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) {
1679 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
1680
1681 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1682 GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
1683 /*is_fusion=*/true));
1684
1685 int unroll_factor;
1686 if (!MayPreventVectorization(fusion)) {
1687 unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1688 } else {
1689 unroll_factor = 1;
1690 }
1691
1692 bool row_vectorized;
1693 int num_big_inputs;
1694 std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(fusion);
1695 bool few_waves = [fusion, row_vectorized, num_big_inputs]() mutable {
1696 for (mlir::Operation& op : fusion.getRegion().front()) {
1697 if (mlir::isa<mlir::bufferization::ToTensorOp,
1698 mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1699 mlir::mhlo::ReturnOp, mlir::mhlo::ConstantOp>(op)) {
1700 continue;
1701 }
1702 HloOpcode opcode = *MhloToHloOpcode(&op);
1703 if (HloInstruction::IsOpElementwise(opcode)) {
1704 continue;
1705 }
1706 if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1707 if (broadcast.broadcast_dimensions().empty() ||
1708           // More than 3 big inputs cause a speed regression.
1709 (row_vectorized && num_big_inputs <= 3)) {
1710 continue;
1711 }
1712 }
1713 VLOG(2) << "few_waves not enabled due to: " << MlirToString(&op);
1714 return false;
1715 }
1716 return true;
1717 }();
1718
1719 Shape element_shape = GetShape(fusion.getOutputBuffers()[0]);
1720 LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
1721 row_vectorized};
1722   // Check that the shape is supported.
1723 if (launch_config.row_vectorized &&
1724 ThreadsPerBlockRowVectorized(element_shape,
1725 ir_emitter_context_->gpu_device_info(),
1726 launch_config) <= 0) {
1727 VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
1728 launch_config.row_vectorized = false;
1729 launch_config.few_waves = false;
1730 }
1731
1732 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
1733 CalculateLaunchDimensions(
1734 element_shape, ir_emitter_context_->gpu_device_info(),
1735 launch_config));
1736
1737 std::vector<llvm_ir::IrArray> ir_arrays;
1738 Thunk* kernel_thunk;
1739 {
1740 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk_ptr,
1741 BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
1742 launch_dimensions));
1743 kernel_thunk = kernel_thunk_ptr.get();
1744 thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1745 }
1746
1747 absl::Span<llvm_ir::IrArray> operand_arrays =
1748 absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size());
1749 absl::Span<llvm_ir::IrArray> output_element_arrays =
1750 absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size(),
1751 fusion.getOutputBuffers().size());
1752
1753 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1754 GetNestedComputer());
1755 FusedIrEmitter fused_emitter(elemental_emitter);
1756
1757 for (int i = 0; i < fusion.getInputBuffers().size(); i++) {
1758 auto* builder = &b_;
1759 auto ir_array = operand_arrays[i];
1760 fused_emitter.BindGenerator(
1761 *fused_computation->parameter_instruction(i),
1762 [builder, ir_array](llvm_ir::IrArray::Index index) {
1763 return ir_array.EmitReadArrayElement(index, builder);
1764 });
1765 }
1766 launch_config.logical_order = EnableLogicalIndexGenerationForOutput(
1767 launch_config, launch_dimensions, operand_arrays, output_element_arrays);
1768 TF_ASSIGN_OR_RETURN(
1769 auto element_generator,
1770 fused_emitter.GetGenerator(*fused_computation->root_instruction()));
1771
1772 llvm::Type* index_type =
1773 GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_);
1774
1775 TF_RETURN_IF_ERROR(
1776 ParallelLoopEmitter(element_generator, output_element_arrays,
1777 launch_dimensions, &b_, launch_config)
1778 .EmitLoop(GetIrNameFromLoc(fusion->getLoc()), index_type));
1779
1780 b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1781 return OkStatus();
1782 }
1783
1784 // Returns whether any of the roots of the fusion are unnested reductions.
1785 static bool HasAnyUnnestedReductionRoot(mlir::lmhlo::FusionOp fusion) {
1786 return absl::c_any_of(fusion.getFusionRoots(), [&](mlir::Operation* op) {
1787 return IsReductionFromOrToContiguousDimensions(op);
1788 });
1789 }
1790
1791 Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) {
1792 auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op);
1793 TF_ASSIGN_OR_RETURN(
1794 const HloComputation* fused_computation,
1795 GetOrCreateSubComputationFromRegion(&fusion_op.getRegion(),
1796 /*is_fusion=*/true));
1797
1798 if (HasAnyUnnestedReductionRoot(fusion_op)) {
1799 return EmitUnnestedReduction(fusion_op);
1800 }
1801
1802 auto fusion_results = fusion_op.getFusionResults();
1803 TF_RET_CHECK(!fusion_results.empty());
1804 if (fusion_results.size() > 1) {
1805     // If the root is a tuple, the fusion can be either a reduce or a slice
1806     // input fusion.
1807 if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) {
1808       // The emitter doesn't support all cases. If a case is not supported, fall
1809       // back to ElementalIrEmitter.
1810 auto status = EmitInputFusibleNonStridedSlices(op);
1811 if (status.code() == tensorflow::error::FAILED_PRECONDITION) {
1812 return EmitLoopFusion(op);
1813 }
1814 return status;
1815 }
1816 }
1817
1818 mlir::Operation* fusion_root = fusion_results[0].getDefiningOp();
1819 if (mlir::isa<mlir::mhlo::ScatterOp>(fusion_root)) {
1820 return EmitScatter(fusion_op, fused_computation);
1821 }
1822
1823 if (!IsSingleInstructionFusion(fusion_op) &&
1824 CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
1825 fusion_op, ir_emitter_context_->allocations())) {
1826 return EmitDynamicUpdateSlice(fusion_op, fused_computation);
1827 }
1828
1829 if (auto copy = mlir::dyn_cast<mlir::mhlo::CopyOp>(fusion_root);
1830 copy && IsSingleInstructionFusion(fusion_op)) {
1831 auto operands = GetHloOperands(fusion_op);
1832 auto outputs = GetHloOutputs(fusion_op);
1833 TF_RET_CHECK(operands.size() == 1);
1834 TF_RET_CHECK(outputs.size() == 1);
1835 auto operand_shape = GetShape(operands[0]);
1836 auto output_shape = GetShape(outputs[0]);
1837
1838 CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
1839 auto maybe_slice = GetAllocationSlice(operands[0]);
1840 if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
1841 maybe_slice.ok()) {
1842 // Copy the operand into the output if it's not the same buffer already.
1843 auto operand_buffer = *maybe_slice;
1844 auto destination_buffer = *GetAllocationSlice(outputs[0]);
1845 if (operand_buffer != destination_buffer) {
1846 AddThunkToThunkSequence(std::make_unique<DeviceToDeviceCopyThunk>(
1847 GetThunkInfo(op),
1848 /*source_address=*/operand_buffer,
1849 /*destination_buffer=*/destination_buffer,
1850 /*mem_size=*/
1851 ByteSizeOf(operand_shape)));
1852 }
1853 return OkStatus();
1854 }
1855 }
1856
1857 if (std::optional<TransposeDimsAndParams> descr =
1858 Match021Transpose(fused_computation)) {
1859 return Emit021Transpose(*descr, fusion_op);
1860 }
1861
1862 return EmitLoopFusion(op);
1863 }
1864
1865 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
1866 const ReductionOutputMap& result_ir_arrays, const IrArray::Index& index,
1867 const ReductionCodegenInfo& reduction_info,
1868 const ExtraOutputGensMap& extra_output_gens) {
1869   // Compute all extra output values before writing them. This avoids
1870   // overwriting aliased input/output buffers before all reads have occurred.
1871 absl::flat_hash_map<const HloInstruction*, llvm::Value*>
1872 extra_output_ir_values;
1873 for (const auto& p : extra_output_gens) {
1874 TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
1875 p.second(index));
1876 extra_output_ir_values[p.first] = extra_output_ir_value;
1877 }
1878 for (const auto& p : extra_output_ir_values) {
1879 absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays.at(p.first);
1880 CHECK_EQ(result_ir.size(), 1);
1881 result_ir[0].EmitWriteArrayElement(
1882 index, p.second, &b_, /*use_linear_index=*/
1883 reduction_info.GetNumPartialResults() == 1);
1884 }
1885 return OkStatus();
1886 }
1887
1888 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(
1889 const std::string& op_name) {
1890 if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
1891 return Unimplemented(
1892 "HLO instruction %s does not have a deterministic implementation, "
1893 "but run-to-run determinism is required by "
1894 "--xla_gpu_deterministic_ops.",
1895 op_name);
1896 }
1897 return OkStatus();
1898 }
1899
1900 Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
1901 auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
1902
1903 const Shape source_shape = GetShape(select_and_scatter_op.getSource());
1904 const Shape operand_shape = GetShape(select_and_scatter_op.getOperand());
1905 const int64_t rank = operand_shape.rank();
1906
1907 CHECK_EQ(rank, source_shape.rank());
1908 if (select_and_scatter_op.getWindowDimensions()) {
1909 CHECK_EQ(rank, select_and_scatter_op.getWindowDimensions()->size());
1910 }
1911
1912 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(
1913 mlir::GetNameFromLoc(select_and_scatter_op.getLoc())));
1914
1915 std::string name = GetIrNameFromLoc(select_and_scatter_op.getLoc());
1916
1917 // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
1918 // consisting of two thunks, an initializer KernelThunk that initializes
1919 // the output and another KernelThunk that accumulates the scattered
1920 // elements.
1921 ThunkSequence thunks;
1922 thunks.emplace_back();
1923 TF_ASSIGN_OR_RETURN(
1924 thunks.back(),
1925 BuildInitializerThunk(op, select_and_scatter_op.getInitValue(),
1926 select_and_scatter_op.getOut()));
1927
1928 TF_ASSIGN_OR_RETURN(
1929 LaunchDimensions launch_dimensions,
1930 CalculateLaunchDimensions(source_shape,
1931 ir_emitter_context_->gpu_device_info()));
1932 std::vector<llvm_ir::IrArray> ir_arrays;
1933 thunks.emplace_back();
1934 // Init value is not needed in IR emission.
1935 TF_ASSIGN_OR_RETURN(
1936 thunks.back(),
1937 BuildKernelThunk(
1938 select_and_scatter_op,
1939 {select_and_scatter_op.getOperand(),
1940 select_and_scatter_op.getSource(), select_and_scatter_op.getOut()},
1941 Thunk::ThunkInfo(), &ir_arrays, launch_dimensions));
1942
1943 CHECK_EQ(ir_arrays.size(), 3);
1944 const IrArray& operand_array = ir_arrays[0];
1945 const IrArray& source_array = ir_arrays[1];
1946 const IrArray& out_array = ir_arrays[2];
1947
1948 auto select_and_scatter_thunk =
1949 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks));
1950
1951 llvm::Type* index_type = GetIndexTypeForKernel(
1952 select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
1953 auto index_typed_constant = [&](uint64_t c) -> llvm::Constant* {
1954 return llvm::ConstantInt::get(index_type, c);
1955 };
1956
1957 // kSelectAndScatter is implemented as two kernel launches: the first launch
1958 // initializes the output array to the given initial value,
1959 // and the second accumulates the "source" matrix to the
1960 // selected elements in the output array. The first launch is already
1961 // implemented by the initializer thunk generated earlier, so this function
1962 // only needs to take care of the select-and-scatter part.
1963 //
1964 // Pseudo code for select-and-scatter:
1965 //
1966 // for (coordinates S in the source): # This loop is parallel.
1967 // initialized_flag = false
1968 // for (coordinates W in the window):
1969 // I = S * stride + W - pad_low
1970 // if I within bounds of operand:
1971 // if !(initialized_flag and select(selected_value, operand(I))):
1972 // selected_value = operand(I)
1973 // selected_index = I
1974 // initialized_flag = true
1975 // if initialized_flag:
1976 // output(selected_index) = scatter(output(selected_index), source(S))
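  // A small 1-D example (hypothetical values): operand = [1, 9, 3, 7, 5],
  // window size 3, stride 2, no padding, source = [10, 20], init = 0,
  // select = GE (keep the maximum), scatter = add. S=0 selects operand index 1
  // (value 9) and adds 10 there; S=1 selects index 3 (value 7) and adds 20,
  // giving output = [0, 10, 0, 20, 0].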
1977 auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
1978 // Allocate space to keep the currently selected value, its index, and a
1979 // boolean flag if the value is initialized. The initialized_flag is set
1980 // false.
1981 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
1982 llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(), module_),
1983 "selected_value_address", &b_);
1984
1985 llvm::AllocaInst* selected_index_address =
1986 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
1987 index_type, index_typed_constant(rank), "selected_index_address",
1988 &b_);
1989
1990 llvm::AllocaInst* initialized_flag_address =
1991 llvm_ir::EmitAllocaAtFunctionEntry(b_.getInt1Ty(),
1992 "initialized_flag_address", &b_);
1993 Store(b_.getInt1(false), initialized_flag_address);
1994
1995 // Create the inner loop to iterate over the window.
1996 llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
1997 index_type);
1998
1999 DimensionVector window_size;
2000 mlir::DenseIntElementsAttr window_dimensions =
2001 select_and_scatter_op.getWindowDimensions().getValue();
2002 for (const auto& dim : window_dimensions) {
2003 window_size.push_back(dim.getSExtValue());
2004 CHECK_GT(dim.getSExtValue(), 0);
2005 }
2006
2007 const IrArray::Index window_index = window_loops.AddLoopsForShape(
2008 ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2009 "window");
2010 llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2011 &b_);
2012
2013     // Compute the operand index to visit and check whether it is within
2014     // bounds. Using an unsigned comparison also covers the check that the
2015     // operand index is >= 0.
2016 std::vector<llvm::Value*> operand_multi_index(source_index.size());
2017 llvm::Value* in_bounds_condition = b_.getInt1(true);
2018
2019 auto strides = *select_and_scatter_op.getWindowStrides();
2020 auto paddings = *select_and_scatter_op.getPadding();
2021
2022 for (auto stride_and_padding :
2023 llvm::enumerate(llvm::zip(strides, paddings))) {
2024 const int i = stride_and_padding.index();
2025 int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2026 int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2027
2028 llvm::Value* strided_index =
2029 NSWMul(source_index[i], index_typed_constant(stride));
2030 operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2031 index_typed_constant(padding));
2032 llvm::Value* index_condition = ICmpULT(
2033 operand_multi_index[i],
2034 index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2035 in_bounds_condition = And(in_bounds_condition, index_condition);
2036 }
2037
2038 // Only need to do something if the operand index is within the bounds.
2039 // First check if the initialized_flag is set.
2040 llvm_ir::LlvmIfData if_in_bounds =
2041 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2042 llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2043 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2044 Load(initialized_flag_address->getAllocatedType(),
2045 initialized_flag_address),
2046 "initialized", &b_);
2047
2048 // If the initialized_flag is false, initialize the selected value and index
2049 // with the currently visiting operand.
2050 llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
2051 const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2052 for (int64_t i = 0; i < rank; ++i) {
2053 llvm::Value* selected_index_address_slot =
2054 InBoundsGEP(selected_index_address->getAllocatedType(),
2055 selected_index_address, {b_.getInt32(i)});
2056 Store(operand_index[i], selected_index_address_slot);
2057 }
2058 };
2059 IrArray::Index operand_index(operand_multi_index, operand_shape,
2060 index_type);
2061 llvm::Value* operand_data =
2062 operand_array.EmitReadArrayElement(operand_index, &b_);
2063 Store(operand_data, selected_value_address);
2064 save_operand_index(operand_index);
2065 Store(b_.getInt1(true), initialized_flag_address);
2066
2067 // If the initialized_flag is true, call the `select` function to
2068 // potentially update the selected value and index with the currently
2069 // visiting operand.
2070 llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2071 llvm::Value* operand_address =
2072 operand_array.EmitArrayElementAddress(operand_index, &b_);
2073 llvm::AllocaInst* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2074 llvm_ir::PrimitiveTypeToIrType(PRED, module_), "select_return_buffer",
2075 &b_);
2076
2077 TF_ASSIGN_OR_RETURN(
2078 const HloComputation* select_computation,
2079 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.getSelect(),
2080 /*is_fusion=*/false));
2081
2082 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2083 *select_computation, {selected_value_address, operand_address},
2084 select_return_buffer));
2085 llvm::Value* result =
2086 Load(select_return_buffer->getAllocatedType(), select_return_buffer);
2087
2088 // If the 'select' function returns false, update the selected value and the
2089 // index to the currently visiting operand.
2090 llvm::Value* cond =
2091 ICmpNE(result,
2092 llvm::ConstantInt::get(
2093 llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2094 "boolean_predicate");
2095 llvm_ir::LlvmIfData if_select_lhs =
2096 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2097 llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2098 Store(Load(operand_array.GetElementLlvmType(), operand_address),
2099 selected_value_address);
2100 save_operand_index(operand_index);
2101
2102 // If the initialized_flag is true, write to the selected index of the
2103 // output; otherwise the window is outside the source (in the padding) and
2104 // should be ignored.
2105 llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2106 &b_);
2107 llvm_ir::LlvmIfData if_should_store = llvm_ir::EmitIfThenElse(
2108 Load(initialized_flag_address->getAllocatedType(),
2109 initialized_flag_address),
2110 "should-store", &b_, /*emit_else=*/false);
2111 llvm_ir::SetToFirstInsertPoint(if_should_store.true_block, &b_);
2112
2113 // After iterating over the window elements, scatter the source element to
2114 // the selected index of the output. The value we store at the output
2115 // location is computed by calling the `scatter` function with the source
2116 // value and the current output value.
2117 std::vector<llvm::Value*> selected_multi_index;
2118 for (int64_t i = 0; i < rank; ++i) {
2119 llvm::Value* selected_index_address_slot =
2120 InBoundsGEP(selected_index_address->getAllocatedType(),
2121 selected_index_address, {b_.getInt32(i)});
2122 selected_multi_index.push_back(
2123 Load(selected_index_address->getAllocatedType(),
2124 selected_index_address_slot));
2125 }
2126 const Shape output_shape = GetShape(select_and_scatter_op.getOut());
2127 llvm::Value* source_value_address =
2128 source_array.EmitArrayElementAddress(source_index, &b_);
2129 IrArray::Index selected_index(selected_multi_index, output_shape,
2130 operand_index.GetType());
2131 llvm::Value* output_value_address =
2132 out_array.EmitArrayElementAddress(selected_index, &b_);
2133
2134 TF_ASSIGN_OR_RETURN(
2135 const HloComputation* scatter_computation,
2136 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.getScatter(),
2137 /*is_fusion=*/false));
2138
2139 return EmitAtomicOperationForNestedComputation(
2140 *scatter_computation, output_value_address, source_value_address,
2141 source_array.GetElementLlvmType());
2142 };
2143
2144 AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
2145 return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2146 &b_)
2147 .EmitLoop(name, index_type);
2148 }
2149
2150 Status IrEmitterUnnested::EmitWhile(mlir::Operation* op) {
2151 auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(op);
2152
2153 auto cond_result = GetHloOutputs(while_op);
2154 TF_RET_CHECK(cond_result.size() == 1);
2155 TF_RET_CHECK(cond_result[0]
2156 .getType()
2157 .cast<mlir::ShapedType>()
2158 .getElementType()
2159 .isInteger(/*width=*/1))
2160 << "While condition computation must return bool";
2161
2162 // Build ForThunk for conformant while loops, otherwise build WhileThunk.
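  // For example, a counted loop whose trip count the compiler has proven and
  // attached via getTripCount() replays its body thunks a fixed number of
  // times as a ForThunk, whereas a data-dependent loop becomes a WhileThunk
  // that re-reads the boolean condition buffer before each iteration.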
2163 if (while_op.getTripCount()) {
2164 TF_ASSIGN_OR_RETURN(auto thunk, BuildForThunk(while_op, GetThunkInfo(op),
2165 *while_op.getTripCount()));
2166 AddThunkToThunkSequence(std::move(thunk));
2167 } else {
2168 TF_ASSIGN_OR_RETURN(auto thunk,
2169 BuildWhileThunk(while_op, GetThunkInfo(op)));
2170 AddThunkToThunkSequence(std::move(thunk));
2171 }
2172 return OkStatus();
2173 }
2174
2175 Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) {
2176 auto rng_op = mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(op);
2177
2178 // Emit a kernel to increment the global state for Philox RNG algorithm.
2179 std::vector<llvm_ir::IrArray> ir_arrays;
2180 TF_ASSIGN_OR_RETURN(
2181 auto kernel_thunk,
2182 BuildKernelThunk(rng_op, rng_op.getState(), GetThunkInfo(op), &ir_arrays,
2183 LaunchDimensions()));
2184 AddThunkToThunkSequence(std::move(kernel_thunk));
2185
2186 llvm::Value* old_state =
2187 llvm_ir::RngGetAndUpdateState(rng_op.getDelta(), module_, &b_);
2188
2189 const Shape shape = GetShape(rng_op.getState());
2190
2191 llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2192 llvm_ir::IrArray::Index(
2193 /*linear=*/b_.getInt64(0), shape, &b_),
2194 &b_, "rng_state_address");
2195 output_address = BitCast(
2196 output_address, llvm::PointerType::get(
2197 old_state->getType(),
2198 output_address->getType()->getPointerAddressSpace()));
2199 Store(old_state, output_address);
2200
2201 return OkStatus();
2202 }
2203
2204 Status IrEmitterUnnested::EmitScatter(mlir::Operation* op) {
2205 ThunkSequence thunks;
2206
2207 auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(op);
2208
2209 if (!scatter_op.getUniqueIndices()) {
2210 TF_RETURN_IF_ERROR(
2211 AssertNonDeterminismIsOkay(GetIrNameFromLoc(scatter_op.getLoc())));
2212 }
2213
2214 TF_ASSIGN_OR_RETURN(auto operand_buffer,
2215 GetAllocationSlice(scatter_op.getOperand()));
2216 TF_ASSIGN_OR_RETURN(auto output_buffer,
2217 GetAllocationSlice(scatter_op.getOutput()));
2218
2219 // Copy the operand into the output if it's not the same buffer already.
2220 if (operand_buffer != output_buffer) {
2221 thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2222 Thunk::ThunkInfo(),
2223 /*source_address=*/operand_buffer,
2224 /*destination_buffer=*/output_buffer,
2225 /*mem_size=*/
2226 ShapeUtil::ByteSizeOf(GetShape(scatter_op.getOutput()))));
2227 }
2228
2229 const Shape& data_shape = GetShape(scatter_op.getUpdates());
2230 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
2231 CalculateLaunchDimensions(
2232 data_shape, ir_emitter_context_->gpu_device_info()));
2233
2234 // Create kernel thunk for all operands except the first one (`operand`). The
2235 // code generated for scatter below assumes that the input operand is already
2236 // copied into the output, so does not use it in codegen.
2237 std::vector<llvm_ir::IrArray> ir_arrays;
2238 thunks.emplace_back();
2239 TF_ASSIGN_OR_RETURN(
2240 thunks.back(),
2241 BuildKernelThunk(scatter_op, scatter_op.getOperands().drop_front(),
2242 GetThunkInfo(op), &ir_arrays, launch_dimensions));
2243
2244 CHECK_EQ(ir_arrays.size(), 3);
2245 const IrArray& scatter_indices = ir_arrays[0];
2246 const IrArray& updates = ir_arrays[1];
2247 const IrArray& output = ir_arrays[2];
2248
2249 auto get_index_type = [&](int64_t launch_size) {
2250 return GetIndexTypeForKernel(scatter_op, launch_size, &b_);
2251 };
2252
2253 TF_RETURN_IF_ERROR(EmitScatter(
2254 thunks.back().get(), scatter_op, launch_dimensions, output,
2255 /*scatter_indices_gen=*/
2256 [&](const IrArray::Index& index) {
2257 return scatter_indices.EmitReadArrayElement(index, &b_,
2258 "scatter_index");
2259 },
2260 /*updates_gen=*/
2261 [&](const IrArray::Index& index) {
2262 return updates.EmitReadArrayElement(index, &b_, "update");
2263 },
2264 /* get_index_type=*/
2265 get_index_type));
2266
2267 // Elide the sequential thunk if there's no copy.
2268 if (thunks.size() == 1) {
2269 AddThunkToThunkSequence(std::move(thunks[0]));
2270 } else {
2271 AddThunkToThunkSequence(
2272 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2273 }
2274
2275 return OkStatus();
2276 }
2277
2278 Status IrEmitterUnnested::EmitScatter(
2279 Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2280 const LaunchDimensions& launch_dimensions, const llvm_ir::IrArray& output,
2281 const llvm_ir::ElementGenerator& scatter_indices_gen,
2282 const llvm_ir::ElementGenerator& updates_gen,
2283 std::function<llvm::Type*(int64_t)> get_index_type) {
2284 const Shape operand_shape = GetShape(scatter.getOperand());
2285 CHECK(ShapeUtil::Equal(GetShape(scatter.getOutput()), operand_shape));
2286
2287 TF_ASSIGN_OR_RETURN(
2288 const HloComputation* update_computation,
2289 GetOrCreateSubComputationFromRegion(&scatter.getUpdateComputation(),
2290 /*is_fusion=*/false));
2291
2292 ScatterDescriptor desc;
2293 desc.name = GetIrNameFromLoc(scatter.getLoc());
2294 desc.operand_shape = operand_shape;
2295 desc.scatter_indices_shape = GetShape(scatter.getScatterIndices());
2296 desc.updates_shape = GetShape(scatter.getUpdates());
2297 desc.dim_numbers = scatter.getScatterDimensionNumbers();
2298 desc.unique_indices = scatter.getUniqueIndices();
2299 desc.update_computation = update_computation;
2300 desc.output = output;
2301 desc.scatter_indices_gen = scatter_indices_gen;
2302 desc.updates_gen = updates_gen;
2303 desc.get_index_type = get_index_type;
2304 return EmitScatter(desc, thunk, launch_dimensions);
2305 }
2306
2307 Status IrEmitterUnnested::EmitScatter(
2308 const ScatterDescriptor& desc, Thunk* thunk,
2309 const LaunchDimensions& launch_dimensions) {
2310 if (!desc.unique_indices) {
2311 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2312 }
2313 auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2314 std::vector<llvm::Value*> raw_window_multidim;
2315 std::vector<llvm::Value*> input_scatter_multidim;
2316 std::vector<int64_t> raw_window_bounds;
2317
2318 // Partition the index into window indices and scatter indices.
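    // For example (hypothetical): for updates of shape [N, W0, W1] with
    // update_window_dims = {1, 2}, components 1 and 2 of `index` are window
    // indices (with bounds W0 and W1) and component 0 is a scatter index.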
2319 for (int64_t i = 0, e = index.size(); i != e; ++i) {
2320       // For window indices, also remember the window size; this comes in handy
2321       // later.
2322 if (llvm::is_contained(desc.dim_numbers.getUpdateWindowDims(), i)) {
2323 raw_window_multidim.push_back(index[i]);
2324 raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2325 } else {
2326 input_scatter_multidim.push_back(index[i]);
2327 }
2328 }
2329 DCHECK_EQ(raw_window_multidim.size(),
2330 desc.dim_numbers.getUpdateWindowDims().size());
2331
2332 // Apply inserted_window_dims to the window dimensions.
2333 int64_t raw_window_multidim_idx = 0;
2334 llvm::SmallVector<llvm::Value*> input_window_multidim;
2335 llvm::SmallVector<int64_t> input_window_bounds;
2336 const int64_t rank = desc.operand_shape.rank();
2337 input_window_bounds.reserve(rank);
2338 input_window_multidim.reserve(rank);
2339
2340 for (int64_t i = 0; i != rank; ++i) {
2341 if (llvm::is_contained(desc.dim_numbers.getInsertedWindowDims(), i)) {
2342 input_window_bounds.push_back(1); // Trivial dimension.
2343 input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2344 } else {
2345 input_window_bounds.push_back(
2346 raw_window_bounds[raw_window_multidim_idx]);
2347 input_window_multidim.push_back(
2348 raw_window_multidim[raw_window_multidim_idx]);
2349 ++raw_window_multidim_idx;
2350 }
2351 }
2352 DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2353
2354 // Insert a 1 dimension at the end if index_vector_dim requests one.
2355 Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2356 if (desc.dim_numbers.getIndexVectorDim() ==
2357 desc.scatter_indices_shape.rank()) {
2358 scatter_indices_shape_fixed.add_dimensions(1);
2359 scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2360 desc.dim_numbers.getIndexVectorDim());
2361 }
2362
2363 // Now load the indices corresponding to the current window from
2364 // scatter_indices.
2365 std::vector<llvm::Value*> raw_scatter_index_multidim =
2366 input_scatter_multidim;
2367 raw_scatter_index_multidim.insert(raw_scatter_index_multidim.begin() +
2368 desc.dim_numbers.getIndexVectorDim(),
2369 nullptr);
2370 llvm::Value* is_in_bounds = b_.getTrue();
2371 for (int64_t i = 0,
2372 e = desc.dim_numbers.getScatterDimsToOperandDims().size();
2373 i != e; ++i) {
2374       // Our index is stored along index_vector_dim; insert that into the lookup
2375       // index into scatter_indices.
2376 raw_scatter_index_multidim[desc.dim_numbers.getIndexVectorDim()] =
2377 index.GetConstantWithIndexType(i);
2378 llvm_ir::IrArray::Index raw_scatter_index_index(
2379 raw_scatter_index_multidim, scatter_indices_shape_fixed,
2380 index.GetType());
2381
2382 int64_t operand_dim = desc.dim_numbers.getScatterDimsToOperandDims()[i];
2383 TF_ASSIGN_OR_RETURN(
2384 llvm::Value* const loaded_scatter_index,
2385 desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2386 scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2387 // And add the index to our window index. This yields the output index.
2388 llvm::Value* casted_scatter_index =
2389 IntCast(loaded_scatter_index, index.GetType(),
2390 /*isSigned=*/true);
2391 llvm::Value* dim_offset =
2392 Add(input_window_multidim[operand_dim], casted_scatter_index);
2393 input_window_multidim[operand_dim] = dim_offset;
2394
2395 // Also do the bounds check now.
2396 int64_t max_index = desc.operand_shape.dimensions(operand_dim) -
2397 input_window_bounds[operand_dim] + 1;
2398 // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2399 // --> index u< dim_size-window_size+1
2400 is_in_bounds =
2401 And(is_in_bounds, ICmpULT(casted_scatter_index,
2402 index.GetConstantWithIndexType(max_index)));
2403 }
2404
2405 llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2406 is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2407 llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2408     // All done; now just read the computed input value from the window and do
2409     // an atomic store to the computed location in the output.
2410 llvm_ir::IrArray::Index input_window_index(
2411 input_window_multidim, desc.output.GetShape(), index.GetType());
2412 llvm::Value* output_address =
2413 desc.output.EmitArrayElementAddress(input_window_index, &b_);
2414 llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2415 llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2416 module_),
2417 "input_address", &b_);
2418 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2419 desc.updates_gen(index));
2420 Store(input_ir_value, input_address);
2421
2422 if (!desc.unique_indices) {
2423 return EmitAtomicOperationForNestedComputation(
2424 *desc.update_computation, output_address, input_address,
2425 desc.output.GetElementLlvmType());
2426 } else {
2427 return EmitCallToNestedComputation(*desc.update_computation,
2428 {output_address, input_address},
2429 output_address);
2430 }
2431 };
2432
2433 // Launch a kernel that reads every element in the updates tensor. We could
2434 // also do one kernel per window instead if bounds checks turn out to be a
2435 // bottleneck.
2436 return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2437 launch_dimensions, &b_)
2438 .EmitLoop(desc.name,
2439 desc.get_index_type(launch_dimensions.launch_bound()));
2440 }
2441
2442 // This transformation should be migrated off. See b/171334474.
2443 StatusOr<HloComputation*>
2444 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2445 bool is_fusion) {
2446 std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2447 if (module == nullptr) {
2448 std::vector<Shape> operand_shapes, output_shapes;
2449 if (is_fusion) {
2450 mlir::Operation* clone = region->getParentOp()->clone();
2451 region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).getRegion();
2452 TF_RETURN_IF_ERROR(
2453 ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2454 }
2455
2456 xla::XlaComputation xla_computation;
2457 mlir::MlirToHloConversionOptions options;
2458 options.propagate_layouts = true;
2459 options.propagate_bitcast_layouts_to_backend_config = true;
2460 options.legalize_node_names = false;
2461 TF_RETURN_IF_ERROR(
2462 ConvertRegionToComputation(region, &xla_computation, options));
2463
2464 if (is_fusion) {
2465 region->getParentOp()->erase();
2466 }
2467
2468 TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2469 TF_ASSIGN_OR_RETURN(
2470 module, HloModule::CreateFromProto(xla_computation.proto(),
2471 HloModuleConfig(program_shape)));
2472
2473 if (is_fusion) {
2474 HloComputation* fused_computation = module->entry_computation();
2475
2476 CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2477 for (int i = 0; i < fused_computation->num_parameters(); i++) {
2478 *fused_computation->parameter_instruction(i)
2479 ->mutable_shape()
2480 ->mutable_layout() = operand_shapes[i].layout();
2481 }
2482 HloInstruction* root = fused_computation->root_instruction();
2483 // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2484       // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
2485 // as much as possible.
2486 if (root->opcode() == HloOpcode::kTuple) {
2487 [&] {
2488 HloInstruction* real_root = nullptr;
2489 int expected_tuple_index = 0;
2490 for (HloInstruction* operand : root->operands()) {
2491 if (operand->opcode() != HloOpcode::kGetTupleElement) {
2492 return;
2493 }
2494 if (real_root == nullptr) {
2495 real_root = operand->mutable_operand(0);
2496 } else if (real_root != operand->operand(0)) {
2497 return;
2498 }
2499 if (expected_tuple_index != operand->tuple_index()) {
2500 return;
2501 }
2502 expected_tuple_index++;
2503 }
2504 fused_computation->set_root_instruction(real_root);
2505 std::vector<HloInstruction*> to_be_removed;
2506 to_be_removed.push_back(root);
2507 for (HloInstruction* operand : root->operands()) {
2508 to_be_removed.push_back(operand);
2509 }
2510 for (auto instr : to_be_removed) {
2511 TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2512 }
2513
2514 root = real_root;
2515 }();
2516 }
2517
2518 if (output_shapes.size() > 1) {
2519 CHECK(root->shape().IsTuple());
2520 CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2521
2522 for (int i = 0; i < output_shapes.size(); i++) {
2523 *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2524 }
2525 } else {
2526 CHECK_EQ(1, output_shapes.size());
2527 *root->mutable_shape() = output_shapes[0];
2528 }
2529 }
2530 // Post-process the generated computation:
2531 // * Sanitize constant names, so that they can be used as LLVM global
2532 // symbols.
2533 // * Propagate layouts for tuple types.
2534 for (HloComputation* computation : module->computations()) {
2535 for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2536 if (instr->opcode() == HloOpcode::kConstant) {
2537           // Notice that IR emitters use the names of constants as LLVM symbol
2538           // names, so it's important that constants in the new module do not
2539           // collide by name with constants in the original module. Make them
2540           // unique by prepending the module name.
2541 //
2542 // TODO(timshen): A better solution would be to plumb the exact
2543 // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2544 // hard because XLA builder doesn't support setting names. Revisit
2545 // this once we get rid of this function, or don't rely on the op name
2546 // (which shouldn't be the identity) to generate LLVM symbols.
2547 instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2548 module->name() + "_" + instr->name()));
2549 }
2550 }
2551 }
2552 }
2553 return module->entry_computation();
2554 }
2555
2556 Status IrEmitterUnnested::EmitSort(mlir::Operation* op) {
2557 auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(op);
2558 ThunkSequence thunks;
2559
2560 std::string op_name = GetIrNameFromLoc(sort_op.getLoc());
2561 std::vector<mlir::Value> operands = GetHloOperands(sort_op);
2562 const Shape& keys_shape = GetShape(operands[0]);
2563 int64_t dimension_to_sort = sort_op.getDimension();
2564 for (int64_t i = 0; i < operands.size(); ++i) {
2565 // We assume that the layout of all involved operands and outputs is the
2566 // same.
2567 TF_RET_CHECK(
2568 LayoutUtil::LayoutsInShapesEqual(keys_shape, GetShape(operands[i])));
2569 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
2570 keys_shape, GetShape(GetHloOutputs(sort_op)[i])));
2571
2572 // If possible, we share buffers. If that is not possible, we need to copy
2573 // the values, because the emitter does the sorting in-place.
2574 TF_ASSIGN_OR_RETURN(auto destination_buffer,
2575 GetAllocationSlice(sort_op.getOutput()[i]));
2576 TF_ASSIGN_OR_RETURN(auto source_address,
2577 GetAllocationSlice(sort_op.getOperands()[i]));
2578 if (destination_buffer != source_address) {
2579 // TODO(b/26783907): Figure out why we never seem to share buffers for
2580 // key/value sort.
2581 VLOG(2) << op_name << " requires initial D2D copy for operand " << i;
2582 thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2583 Thunk::ThunkInfo(),
2584 /*source_address=*/source_address,
2585 /*destination_buffer=*/destination_buffer,
2586 /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(operands[i]))));
2587 }
2588 }
2589
2590 uint64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2591 int64_t num_stages = Log2Ceiling(dimension_to_sort_bound);
2592 VLOG(2) << op_name << " requires " << num_stages << " stages.";
2593 CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2594 CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
2595
2596 // Naive C++ code for the outer loops:
2597 //
2598 // for (int64_t stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2599 // ++stage) {
2600 // int64_t first_xor_mask = (1LL << (stage + 1)) - 1;
2601 // SortInPlace(first_xor_mask);
2602 // for (int64_t mask = stage - 1; mask >= 0; --mask) {
2603 // int64_t later_xor_mask = 1LL << mask;
2604 // SortInPlace(later_xor_mask);
2605 // }
2606 // }
2607 //
2608 // This follows the alternative representation of the algorithm described on
2609 // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2610 //
2611 // Each mask specifies how to derive from one position in the array the
2612 // position with which it should be compared (we calculate the xor of the
2613 // position with the mask).
2614 // As an optimization, we can move the 'mask' loop to inside the
2615 // sorting/comparison loop if the comparisons happen within a small block of
2616 // the array. To make this work, we collect all consecutive masks that are
2617 // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2618 // Each thread then processes one tile of data.
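  // For instance (hypothetical bound, assuming tiling is not disabled below):
  // dimension_to_sort_bound = 10000 gives num_stages = 14 and kTileSize = 2048,
  // so xor masks 0x7ff and smaller are batched into tiled kernels, while masks
  // such as 0x1fff, 0xfff, and 0x800 each get their own untiled kernel.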
2619
2620 const uint64_t kTileSize = std::min(2048ULL, 1ULL << num_stages);
2621
2622 // If we cannot combine several xor masks together, we don't use tiling, so we
2623 // calculate the standard launch dimensions for the shape. However we only
2624 // need to iterate through ~half of the dimension to sort (rounded up to the
2625 // next highest power of 2), because each iteration compares one pair of
2626 // elements.
2627 Shape standard_iteration_shape = keys_shape;
2628 uint64_t standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2629 standard_iteration_shape.set_dimensions(dimension_to_sort,
2630 standard_num_iterations_in_sort_dim);
2631 TF_ASSIGN_OR_RETURN(
2632 LaunchDimensions standard_launch_dimensions,
2633 CalculateLaunchDimensions(standard_iteration_shape,
2634 ir_emitter_context_->gpu_device_info()));
2635
2636 // Calculate the launch dimensions for the case where we use tiling. We split
2637 // the dimension that should be sorted into tiles of size 'kTileSize'. This
2638 // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2639 // of the tile size.
2640 int64_t rounded_bound = RoundUpTo(dimension_to_sort_bound, kTileSize);
2641 Shape iteration_shape = keys_shape;
2642
2643 // We iterate through the element pairs that should be compared.
2644 uint64_t num_iterations_in_sort_dim = rounded_bound / 2;
2645 iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2646 uint64_t num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2647
2648 // For correctness reasons we need exactly 'kTileSize' / 2 many threads per
2649 // block. Each thread is responsible for copying exactly two adjacent elements
2650 // into shared memory, and then does a comparison of two possibly different
2651 // elements taken from shared memory.
2652 const uint64_t kThreadsPerBlock = kTileSize / 2;
2653
2654 // Check whether we should use any tiling. We might not be able to use it if
2655 // we don't have enough threads or enough shared memory. Tiling also does not
2656 // give a speedup if the tile size is < 128.
2657 int64_t total_shared_memory_needed = 0;
2658 for (int64_t i = 0; i < operands.size(); ++i) {
2659 total_shared_memory_needed +=
2660 kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2661 GetShape(operands[i]).element_type());
2662 }
2663 bool no_tiling =
2664 kTileSize < 128 ||
2665 kThreadsPerBlock >
2666 ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2667 total_shared_memory_needed >
2668 ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2669 VLOG(2) << absl::StreamFormat(
2670 "%s %s use tiling. No tiling if any of the following is true: "
2671 "kTileSize=%d < 128, "
2672 "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2673 "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2674 op_name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2675 ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2676 total_shared_memory_needed,
2677 ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2678
2679 uint64_t num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2680 LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2681 VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2682 op_name, num_blocks, kThreadsPerBlock);
2683
2684 std::vector<llvm_ir::IrArray> ir_arrays;
2685 auto emit_kernel = [&](absl::Span<const int64_t> xor_masks) {
2686 VLOG(2) << absl::StreamFormat(
2687 "%s uses kernel for xor masks [%s]", op_name,
2688 absl::StrJoin(xor_masks, ", ", [](std::string* out, int64_t xor_mask) {
2689 absl::StrAppendFormat(out, "0x%x", xor_mask);
2690 }));
2691 thunks.emplace_back();
2692 LaunchDimensions launch_dimensions = xor_masks.size() > 1
2693 ? tiled_launch_dimensions
2694 : standard_launch_dimensions;
2695 TF_ASSIGN_OR_RETURN(
2696 thunks.back(),
2697 BuildKernelThunk(sort_op, sort_op.getOutput(), Thunk::ThunkInfo(),
2698 &ir_arrays, launch_dimensions));
2699 std::vector<IrArray> values_arrays;
2700 values_arrays.reserve(operands.size());
2701 for (int64_t i = 0; i < operands.size(); ++i) {
2702 values_arrays.push_back(ir_arrays[i]);
2703 }
2704 TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
2705 GetOrCreateSubComputationFromRegion(
2706 &sort_op.getComparator(), /*is_fusion=*/false));
2707 return llvm_ir::EmitSortInPlace(
2708 dimension_to_sort, values_arrays, IrName(op_name), xor_masks, &b_,
2709 launch_dimensions,
2710 xor_masks.size() > 1 ? num_iterations_in_sort_dim
2711 : standard_num_iterations_in_sort_dim,
2712 kTileSize,
2713 [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
2714 return EmitCallToNestedComputation(*comparator, operands, output);
2715 });
2716 };
2717 std::vector<int64_t> xor_masks;
2718 for (int64_t stage = 0; stage < num_stages; ++stage) {
2719 for (int64_t mask = stage; mask >= 0; --mask) {
2720 int64_t xor_mask;
2721 if (mask == stage) {
2722 xor_mask = (1LL << (stage + 1)) - 1;
2723 } else {
2724 xor_mask = 1LL << mask;
2725 }
2726 if (xor_mask >= kTileSize || no_tiling) {
2727 if (!xor_masks.empty()) {
2728 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2729 xor_masks.clear();
2730 }
2731 TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
2732 } else {
2733 xor_masks.push_back(xor_mask);
2734 }
2735 }
2736 }
2737 if (!xor_masks.empty()) {
2738 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2739 }
2740 VLOG(2) << absl::StreamFormat(
2741 "%s requires %d thunks (including any D2D copies)", op_name,
2742 thunks.size());
2743
2744 AddThunkToThunkSequence(
2745 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2746 return OkStatus();
2747 }
2748
2749 template <typename ThunkType, typename OpT>
2750 Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) {
2751 auto casted = mlir::cast<OpT>(op);
2752 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2753 GetAllocationSlice(casted.getOperand()));
2754 auto thunk = std::make_unique<ThunkType>(GetThunkInfo(op), result_slice);
2755 AddThunkToThunkSequence(std::move(thunk));
2756 return OkStatus();
2757 }
2758
2759 Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) {
2760 auto collective_permute_op = mlir::cast<mlir::lmhlo::CollectivePermuteOp>(op);
2761
2762 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice,
2763 GetAllocationSlice(collective_permute_op.getOperand()));
2764 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2765 GetAllocationSlice(collective_permute_op.getOutput()));
2766
2767 const Shape shape = GetShape(collective_permute_op.getOperand());
2768 const int64_t replica_count = hlo_module_config_.replica_count();
2769 const int64_t partition_count = hlo_module_config_.num_partitions();
2770
2771 if (NcclCollectivePermuteThunk::IsDegenerate(
2772 collective_permute_op, replica_count, partition_count)) {
2773 // For a degenerate collective permute, just generate a copy thunk.
2774 AddThunkToThunkSequence(std::make_unique<DeviceToDeviceCopyThunk>(
2775 GetThunkInfo(op),
2776 /*source_address=*/source_slice,
2777 /*destination_buffer=*/result_slice,
2778 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2779 } else {
2780 const NcclCollectivePermuteThunk::Buffer buffer = {
2781 /*element_count=*/ShapeUtil::ElementsIn(shape),
2782 /*source_buffer=*/source_slice,
2783 /*destination_buffer=*/result_slice};
2784 auto thunk = std::make_unique<NcclCollectivePermuteThunk>(
2785 GetThunkInfo(op), collective_permute_op, replica_count, partition_count,
2786 buffer);
2787 AddThunkToThunkSequence(std::move(thunk));
2788 }
2789 return OkStatus();
2790 }
2791
2792 Status MaybeAddAllReduceStartThunkToMap(
2793 absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&
2794 all_reduce_start_thunks,
2795 mlir::Operation* op, Thunk* thunk) {
2796 if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
2797 TF_RET_CHECK(all_reduce_start_thunks
2798 .emplace(op, static_cast<NcclAllReduceStartThunk*>(thunk))
2799 .second)
2800 << "all-reduce-start with this unique ID already seen";
2801 }
2802 return OkStatus();
2803 }
2804
2805 template <typename NcclThunkType, typename OpTy>
2806 Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) {
2807 OpTy op = mlir::cast<OpTy>(untyped_op);
2808 int64_t replica_count = hlo_module_config_.replica_count();
2809 int64_t partition_count = hlo_module_config_.num_partitions();
2810 VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
2811 << "; partition count: " << partition_count
2812 << "; operand count: " << op.getOperands().size()
2813 << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
2814
2815 // A given collective op can be degenerate if all the groups formed by it
2816 // are singletons. In such a case, we don't need to do any communication and
2817 // we can just copy the input to the output.
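// For example, an all-reduce whose replica groups all have size 1 combines
// each input with nothing else, so its output is just a copy of that input.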
2818 bool is_degenerate =
2819 NcclThunkType::IsDegenerate(op, replica_count, partition_count);
2820 bool should_use_nccl_thunk =
2821 !is_degenerate && NcclThunkType::CanImplement(op);
2822
2823 // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
2824 // not generate an NcclCollectiveThunk.
2825 std::vector<NcclCollectiveThunk::Buffer> buffers;
2826 buffers.reserve(op.getOperands().size());
2827 for (auto it : llvm::zip(op.getInputs(), op.getOutputs())) {
2828 mlir::Value operand = std::get<0>(it);
2829 mlir::Value result = std::get<1>(it);
2830 const Shape shape = GetShape(operand);
2831 TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
2832 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
2833 buffers.push_back(NcclCollectiveThunk::Buffer{
2834 /*element_count=*/ShapeUtil::ElementsIn(shape),
2835 /*source_buffer=*/source_slice,
2836 /*destination_buffer=*/dest_slice});
2837 }
2838
2839 if (should_use_nccl_thunk) {
2840 auto thunk =
2841 std::make_unique<NcclThunkType>(GetThunkInfo(op), op,
2842 /*buffers=*/std::move(buffers));
2843 // Record thunks for all-reduce-start ops as the done ops need them.
2844 TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
2845 all_reduce_start_thunks_, op, thunk.get()));
2846 AddThunkToThunkSequence(std::move(thunk));
2847 return OkStatus();
2848 }
2849
2850 // Signal that no all-reduce-start thunk was created by recording a nullptr.
2851 TF_RETURN_IF_ERROR(
2852 MaybeAddAllReduceStartThunkToMap(all_reduce_start_thunks_, op, nullptr));
2853
2854 if (!is_degenerate) {
2855 CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op);
2856
2857 std::string message = absl::StrFormat(
2858 "Requested %s not implemented on GPU; replica_count: %d; "
2859 "partition_count: %d, group_mode: %s, operand_count: %d; NCCL support: "
2860 "%d",
2861 NcclThunkType::GetName(), replica_count, partition_count,
2862 CollectiveOpGroupModeToString(group_mode), op.getOperands().size(),
2863 NcclThunkType::NcclIsEnabled());
2864 if (!op.getOperands().empty()) {
2865 const Shape shape = GetShape(op.getOperands().front());
2866 absl::StrAppendFormat(&message, "; first operand array element-type: %s",
2867 PrimitiveType_Name(shape.element_type()));
2868 }
2869 return Unimplemented("%s", message);
2870 }
2871
2872 VLOG(1) << "Collective call is degenerate, not doing NCCL call";
2873
2874 // All-gather with one replica is simply the identity function. Buffer
2875 // assignment expects a copy, so that's what we do.
2876 ThunkSequence thunks;
2877 for (int64_t i = 0; i < buffers.size(); i++) {
2878 const Shape shape = GetShape(op.getOperands()[i]);
2879 thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2880 buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
2881 /*source_address=*/buffers[i].source_buffer,
2882 /*destination_buffer=*/buffers[i].destination_buffer,
2883 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2884 }
2885 if (thunks.size() == 1) {
2886 AddThunkToThunkSequence(std::move(thunks[0]));
2887 } else {
2888 AddThunkToThunkSequence(
2889 std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2890 }
2891 return OkStatus();
2892 }
2893
2894 Status IrEmitterUnnested::EmitAllReduceDone(mlir::Operation* op) {
2895 auto done_op = mlir::cast<mlir::lmhlo_gpu::AllReduceDoneOp>(op);
2896 auto start_op =
2897 done_op.getToken().getDefiningOp<mlir::lmhlo_gpu::AllReduceStartOp>();
2898 auto it = all_reduce_start_thunks_.find(start_op);
2899 TF_RET_CHECK(it != all_reduce_start_thunks_.end())
2900 << "couldn't find thunk for all-reduce-start op";
2901
2902 // Can be null if no all-reduce-start thunk was created (e.g. if the start op
2903 // is degenerate), in which case there's nothing to do here.
2904 if (it->second != nullptr) {
2905 AddThunkToThunkSequence(std::make_unique<NcclAllReduceDoneThunk>(
2906 GetThunkInfo(op), *it->second));
2907 }
2908 all_reduce_start_thunks_.erase(it);
2909 return OkStatus();
2910 }
2911
2912 StatusOr<std::vector<ShapedSlice>> IrEmitterUnnested::GetShapedSlices(
2913 mlir::Operation::operand_range operands) {
2914 std::vector<ShapedSlice> shaped_slices;
2915 shaped_slices.reserve(operands.size());
2916 for (mlir::Value opnd : operands) {
2917 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd));
2918 shaped_slices.push_back(ShapedSlice{slice, GetShape(opnd)});
2919 }
2920 return shaped_slices;
2921 }
2922
2923 StatusOr<std::vector<BufferAllocation::Slice>> IrEmitterUnnested::GetSlices(
2924 mlir::Operation::operand_range operands) {
2925 std::vector<BufferAllocation::Slice> slices;
2926 slices.reserve(operands.size());
2927 for (mlir::Value opnd : operands) {
2928 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd));
2929 slices.push_back(slice);
2930 }
2931 return slices;
2932 }
2933
2934 Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) {
2935 mlir::Operation::operand_range operands =
2936 mlir::cast<mlir::lmhlo::InfeedOp>(op).getOutputs();
2937 TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands));
2938 auto thunk =
2939 std::make_unique<InfeedThunk>(GetThunkInfo(op), std::move(shaped_slices));
2940 AddThunkToThunkSequence(std::move(thunk));
2941
2942 return OkStatus();
2943 }
2944
2945 Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) {
2946 mlir::Operation::operand_range operands =
2947 mlir::cast<mlir::lmhlo::OutfeedOp>(op).getInputs();
2948 TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands));
2949 auto thunk = std::make_unique<OutfeedThunk>(GetThunkInfo(op),
2950 std::move(shaped_slices));
2951 AddThunkToThunkSequence(std::move(thunk));
2952
2953 return OkStatus();
2954 }
2955
2956 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunkImpl(
2957 absl::string_view name, Thunk::ThunkInfo thunk_info,
2958 absl::Span<const BufferSlice> slices,
2959 std::vector<llvm_ir::IrArray>* ir_arrays,
2960 const LaunchDimensions& launch_dimensions) {
2961 // Figure out which buffer allocations need to be passed as arguments to our
2962 // kernel. This is simply all of the allocations referenced in slices,
2963 // plus the XLA temp buffer (if we have it). We always include the temp
2964 // buffer because even if the kernel itself doesn't use it, a nested
2965 // subcomputation within the kernel (e.g. a kMap's computation) might.
2966 absl::flat_hash_set<const BufferAllocation*> buffers_needed;
2967 for (const auto& slice : slices) {
2968 buffers_needed.insert(slice.buffer_slice.allocation());
2969 }
2970 std::optional<const BufferAllocation*> temp_buffer;
2971 for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
2972 if (alloc.IsPreallocatedTempBuffer()) {
2973 if (!temp_buffer.has_value()) {
2974 // Retrieve the first seen temp buffer.
2975 temp_buffer = &alloc;
2976 }
2977 }
2978 }
2979 if (temp_buffer.has_value()) {
2980 buffers_needed.insert(*temp_buffer);
2981 }
2982
2983 // We'll pass a pointer to each of the elements of `non_constant_buffers` to
2984 // our kernel, in this order.
2985 std::vector<const BufferAllocation*> non_constant_buffers;
2986 absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
2987 [](const BufferAllocation* allocation) {
2988 return !allocation->is_constant();
2989 });
2990
2991 absl::c_sort(non_constant_buffers,
2992 [](const BufferAllocation* a, const BufferAllocation* b) {
2993 return a->index() < b->index();
2994 });
2995
2996 llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
2997
2998 // Build a map from a BufferAllocation to the corresponding argument in our
2999 // kernel.
3000 absl::flat_hash_map<const BufferAllocation*, llvm::Value*> kernel_args;
3001 {
3002 auto arg_it = kernel->arg_begin();
3003 auto buffers_it = non_constant_buffers.begin();
3004 for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3005 kernel_args[*buffers_it] = arg_it;
3006
3007 // Annotate all allocations with LLVM's `noalias`.
3008 // There are three kinds of allocations:
3009 // * Read-only allocations, aka input parameters that are not aliased with
3010 // outputs.
3011 // * Read-write allocations, including all output buffers, some of which
3012 // may alias with input HLO parameters, but aliased HLO buffers are always
3013 // assigned with the same allocation.
3014 // * The temp buffer.
3015 //
3016 // Read-only allocations may overlap with each other, but since they are
3017 // not mutated, they can always be annotated with `noalias` per LLVM
3018 // semantics.
3019 //
3020 // Read-write allocations and the temp buffer don't overlap with any other
3021 // allocations, therefore they can also be annotated with `noalias`.
3022 kernel->addParamAttr(
3023 arg_it->getArgNo(),
3024 llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3025 }
3026 }
3027
3028 absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3029 for (const auto& slice : slices) {
3030 if (slice.written) {
3031 buffers_written.insert(slice.buffer_slice);
3032 }
3033 }
3034
3035 ir_arrays->clear();
3036
3037 // For each buffer our kernel might want to touch, bind it to a value derived
3038 // from our kernel args.
3039 for (const BufferSlice& slice : slices) {
3040 const BufferAllocation::Slice& buffer_slice = slice.buffer_slice;
3041
3042 llvm::Value* loc;
3043 if (!slice.constant_name.empty()) {
3044 loc = module_->getGlobalVariable(slice.constant_name);
3045 CHECK_NE(loc, nullptr)
3046 << "Could not find variable '" << slice.constant_name << "'";
3047 } else {
3048 CHECK(!buffer_slice.allocation()->is_constant());
3049 loc =
3050 InBoundsGEP(b_.getInt8Ty(), kernel_args.at(buffer_slice.allocation()),
3051 {b_.getInt64(buffer_slice.offset())});
3052 }
3053
3054 llvm::Type* ir_type = llvm_ir::ShapeToIrType(slice.shape, module_);
3055 llvm_ir::IrArray ir_array(CastToTypedValue(slice.shape, loc, &b_), ir_type,
3056 slice.shape);
3057 if (!buffers_written.contains(slice.buffer_slice)) {
3058 ir_array.MarkInvariantOverWholeProgram(&loc->getContext());
3059 }
3060
3061 ir_arrays->push_back(ir_array);
3062 }
3063
3064 AnnotateThunkLaunchDimensions(launch_dimensions,
3065 std::string(kernel->getName()), module_);
3066
3067 return {std::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3068 std::string(kernel->getName()),
3069 launch_dimensions)};
3070 }
3071
3072 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk(
3073 mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3074 std::vector<llvm_ir::IrArray>* ir_arrays,
3075 const LaunchDimensions& launch_dimensions) {
3076 TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3077
3078 std::vector<BufferSlice> slices;
3079 slices.reserve(operands.size());
3080 for (mlir::Value operand : operands) {
3081 slices.emplace_back();
3082 auto& slice = slices.back();
3083 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3084 GetAllocationSlice(operand, &slice.constant_name));
3085 slice.written = WritesMlirBuffer(op, operand);
3086 slice.shape = GetShape(operand);
3087 }
3088 std::string name = GetIrNameFromLoc(op->getLoc());
3089 return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3090 launch_dimensions);
3091 }
3092
3093 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk(
3094 mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3095 std::vector<llvm_ir::IrArray>* ir_arrays,
3096 const LaunchDimensions& launch_dimensions) {
3097 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3098 auto operands = GetHloOperands(op);
3099 auto outputs = GetHloOutputs(op);
3100
3101 std::vector<BufferSlice> slices;
3102 slices.reserve(operands.size() + outputs.size());
3103 for (mlir::Value operand : operands) {
3104 slices.emplace_back();
3105 BufferSlice& slice = slices.back();
3106 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3107 GetAllocationSlice(operand, &slice.constant_name));
3108 slice.written = false;
3109 slice.shape = GetShape(operand);
3110 }
3111 for (mlir::Value output : outputs) {
3112 slices.emplace_back();
3113 BufferSlice& slice = slices.back();
3114 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3115 GetAllocationSlice(output, &slice.constant_name));
3116 slice.written = true;
3117 slice.shape = GetShape(output);
3118 }
3119 std::string name = GetIrNameFromLoc(op->getLoc());
3120 return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3121 launch_dimensions);
3122 }
3123 return BuildKernelThunk(op, op->getOperands(), thunk_info, ir_arrays,
3124 launch_dimensions);
3125 }
3126
3127 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3128 absl::Span<const uint8_t> init_value, const BufferAllocation::Slice& dest,
3129 const Shape& output_shape) {
3130 int64_t num_bytes = init_value.size();
3131 if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) {
3132 return std::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3133 }
3134
3135 // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3136 // repeating the literal 4 or 2 times, so long as the destination buffer is
3137 // an even multiple of 32 bits long.
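// For example, an 8-bit init value 0xAB yields pattern16 == 0xABAB and
// pattern32 == 0xABABABAB below, which a single 32-bit memset can splat
// across the destination buffer.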
3138 if ((num_bytes == 1 || num_bytes == 2) &&
3139 ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3140 uint16_t pattern16;
3141 if (num_bytes == 1) {
3142 uint8_t b = init_value.front();
3143 pattern16 = uint16_t{b} | (uint16_t{b} << 8);
3144 } else {
3145 memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3146 }
3147 uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16);
3148 return std::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3149 pattern32, dest);
3150 }
3151
3152 // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3153 // memset so long as all 32-bit words of the scalar are equal to each other.
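// Comparing the buffer against itself shifted by 4 bytes checks exactly that:
// word i must equal word i + 1 for every i, hence all words equal the first.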
3154 if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3155 memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3156 0) {
3157 uint32_t word;
3158 memcpy(&word, init_value.data(), sizeof(word));
3159 return std::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3160 dest);
3161 }
3162
3163 return nullptr;
3164 }
3165
3166 StatusOr<std::unique_ptr<Thunk>>
3167 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3168 mlir::Value dest) {
3169 mlir::DenseElementsAttr const_init;
3170 if (auto get_global_memref =
3171 mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
3172 init_value.getDefiningOp())) {
3173 auto global_memref =
3174 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>(
3175 get_global_memref, get_global_memref.getNameAttr());
3176 if (global_memref.getConstant() && global_memref.getInitialValue()) {
3177 // If the initial value happens to be a constant, generate a specialized
3178 // thunk.
3179 const_init = global_memref.getInitialValue()
3180 .getValue()
3181 .cast<mlir::DenseElementsAttr>();
3182 }
3183 } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstantOp>(
3184 init_value.getDefiningOp())) {
3185 const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3186 }
3187
3188 if (const_init) {
3189 std::vector<uint8_t> literal_bytes;
3190 TF_RETURN_IF_ERROR(
3191 CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3192
3193 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
3194
3195 const Shape dest_shape = GetShape(dest);
3196 auto thunk =
3197 BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3198 if (thunk) {
3199 return {std::move(thunk)};
3200 }
3201 }
3202 return std::unique_ptr<Thunk>();
3203 }
3204
3205 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
3206 mlir::Operation* op, mlir::Value init_value, mlir::Value dest) {
3207 // The initial value must be a scalar memref.
3208 auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3209 TF_RET_CHECK(init_type.getRank() == 0);
3210
3211 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3212 TryBuildConstantInitializerThunk(init_value, dest));
3213 if (constant_init_thunk) {
3214 return {std::move(constant_init_thunk)};
3215 }
3216
3217 // Otherwise fall back to our slow initializer code. The thunk in this case
3218 // will just need the IR arrays for the initial value and the destination.
3219 const Shape dest_shape = GetShape(dest);
3220 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3221 CalculateLaunchDimensions(
3222 dest_shape, ir_emitter_context_->gpu_device_info()));
3223 std::vector<llvm_ir::IrArray> ir_arrays;
3224 TF_ASSIGN_OR_RETURN(
3225 std::unique_ptr<Thunk> kernel_thunk,
3226 BuildKernelThunk(op, {init_value, dest}, Thunk::ThunkInfo(), &ir_arrays,
3227 launch_dimensions));
3228
3229 const llvm_ir::IrArray init_array = ir_arrays[0];
3230 const llvm_ir::IrArray dest_array = ir_arrays[1];
3231
3232 std::string name = GetIrNameFromLoc(op->getLoc());
3233 TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3234 [=](const IrArray::Index& index) {
3235 return init_array.EmitReadArrayElement(index, &b_);
3236 },
3237 {dest_array}, launch_dimensions, &b_)
3238 .EmitLoop(GetIrNameFromLoc(op->getLoc())));
3239
3240 return std::move(kernel_thunk);
3241 }
3242
3243 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildFusedInitializerThunk(
3244 mlir::lmhlo::FusionOp fusion, int output_index) {
3245 auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3246 fusion.getFusionRoots()[output_index]);
3247
3248 TF_RET_CHECK(reduce);
3249 TF_RET_CHECK(reduce.getNumResults() == 1);
3250
3251 mlir::Value init_value = reduce.init_values()[0];
3252 mlir::Value dest = fusion.getOutputBuffers()[output_index];
3253 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3254 TryBuildConstantInitializerThunk(init_value, dest));
3255 if (constant_init_thunk) {
3256 return {std::move(constant_init_thunk)};
3257 }
3258
3259 auto input_buffers = fusion.getInputBuffers();
3260
3261 const Shape dest_shape = GetShape(dest);
3262 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3263 CalculateLaunchDimensions(
3264 dest_shape, ir_emitter_context_->gpu_device_info()));
3265 std::vector<llvm_ir::IrArray> ir_arrays;
3266 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk,
3267 BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
3268 launch_dimensions));
3269
3270 const llvm_ir::IrArray dest_array =
3271 ir_arrays[input_buffers.size() + output_index];
3272
3273 const HloComputation* fused_computation =
3274 *GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
3275 /*is_fusion=*/true);
3276
3277 // If init_value was fused into this reduce we have to generate it first.
3278 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
3279 GetNestedComputer());
3280
3281 FusedIrEmitter fused_emitter(elemental_emitter);
3282 for (int i = 0; i < fused_computation->num_parameters(); i++) {
3283 fused_emitter.BindGenerator(
3284 *fused_computation->parameter_instruction(i),
3285 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3286 return ir_arrays[i].EmitReadArrayElement(index, &b_);
3287 });
3288 }
3289 HloInstruction* instr = fused_computation->root_instruction();
3290 if (instr->opcode() != HloOpcode::kTuple) {
3291 CHECK_EQ(0, output_index);
3292 } else {
3293 instr = instr->mutable_operand(output_index);
3294 }
3295 TF_RET_CHECK(instr->shape().IsArray());
3296 TF_ASSIGN_OR_RETURN(auto generator,
3297 fused_emitter.GetGenerator(*instr->operand(1)));
3298 TF_RETURN_IF_ERROR(
3299 ParallelLoopEmitter(generator, {dest_array}, launch_dimensions, &b_)
3300 .EmitLoop(GetIrNameFromLoc(fusion.getLoc())));
3301 return {std::move(kernel_thunk)};
3302 }
3303
3304 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3305 mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) {
3306 // Generate thunk sequence for while 'condition'.
3307 mlir::Region* condition = &while_op.getCond();
3308 TF_ASSIGN_OR_RETURN(
3309 auto ir_emitter_condition,
3310 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3311
3312 TF_RETURN_IF_ERROR(ir_emitter_condition->EmitLmhloRegion(condition));
3313
3314 // Generate thunk sequence for while 'body'.
3315 mlir::Region* body = &while_op.getBody();
3316 TF_ASSIGN_OR_RETURN(
3317 auto ir_emitter_body,
3318 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3319
3320 TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body));
3321
3322 // Extract the condition value from the last op (excluding the terminator op)
3323 // in the condition region.
3324 auto cond_result = GetHloOutputs(while_op);
3325 TF_RET_CHECK(cond_result.size() == 1);
3326 TF_ASSIGN_OR_RETURN(auto cond_result_slice,
3327 GetAllocationSlice(cond_result[0]));
3328
3329 return std::unique_ptr<Thunk>(
3330 new WhileThunk(thunk_info, cond_result_slice,
3331 ir_emitter_condition->ConsumeThunkSequence(),
3332 ir_emitter_body->ConsumeThunkSequence()));
3333 }
3334
3335 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3336 mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
3337 const int64_t loop_limit) {
3338 // Generate thunk sequence for while 'body' (will be used as the For loop body).
3339 TF_ASSIGN_OR_RETURN(
3340 auto ir_emitter_body,
3341 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3342 TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(&while_op.getBody()));
3343
3344 return std::unique_ptr<Thunk>(new ForThunk(
3345 thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence()));
3346 }
3347
3348 Status IrEmitterUnnested::EmitTargetElementLoop(
3349 const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3350 return InternalError("This should be unreachable");
3351 }
3352
3353 // Gets the output offset as calculated from thread_id.x (to be applied to the
3354 // offset calculated from block_id and thread_id.y).
3355 static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme,
3356 llvm::Value* thread_id_x,
3357 llvm::Type* index_ty,
3358 llvm::IRBuilder<>* b) {
3359 int64_t multiplier = tiling_scheme.GetIndexingOrder() == kStridedIndexingX
3360 ? tiling_scheme.GetVectorSize()
3361 : tiling_scheme.GetTileSizeFor(kDimX);
3362 return b->CreateMul(thread_id_x,
3363 llvm::ConstantInt::get(index_ty, multiplier));
3364 }
3365
3366 // Emits loop through the minor (X) dimension of a tile, starting at a given
3367 // offset.
3368 //
3369 // Rough pseudocode:
3370 //
3371 // Given: offset, callback
3372 //
3373 // for (int x = 0; x < x_tile_size / vector_size; ++x) {
3374 // for (int i = 0; i < vector_size; ++i) {
3375 // callback(offset + x * stride * vector_size + i);
3376 // }
3377 // }
3378 static void EmitXTileLoop(
3379 const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
3380 const TilingScheme& tiling_scheme, bool check_x_tile_bounds,
3381 llvm::Value* start_offset_x, llvm::Value* y_loc,
3382 IrEmitterUnnested::ValueVector3 tile_dimensions,
3383 const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3384 const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3385 llvm::Type* index_ty = tile_dimensions[kDimX]->getType();
3386 KernelSupportLibrary ksl(b, llvm_ir::UnrollMode::kDefaultUnroll);
3387 auto constant = [&](int64_t val) {
3388 return llvm::ConstantInt::get(index_ty, val);
3389 };
3390
3391 IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3392 int64_t vector_size = tiling_scheme.GetVectorSize();
3393 int64_t stride_x = tiling_scheme.GetIndexingOrder() == kLinearIndexingX
3394 ? 1
3395 : tiling_scheme.GetNumThreadsFor(kDimX);
3396 KernelSupportLibrary unrolled_ksl(b, llvm_ir::UnrollMode::kFullyUnroll);
3397 unrolled_ksl.For(
3398 "tile_loop",
3399 /*start=*/constant(0),
3400 /*end=*/constant(tiling_scheme.GetTileSizeFor(kDimX) / vector_size),
3401 /*step=*/1, [&](llvm::Value* x) {
3402 for (int64_t i = 0; i < vector_size; i++) {
3403 llvm::Value* linear_index =
3404 b->CreateAdd(b->CreateMul(x, constant(vector_size)), constant(i));
3405 llvm::Value* x_loc = b->CreateAdd(
3406 b->CreateAdd(b->CreateMul(x, constant(stride_x * vector_size)),
3407 constant(i)),
3408 start_offset_x, "x_loc");
3409 IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3410 b->CreateAdd(b->CreateMul(x, constant(stride_x * vector_size)),
3411 constant(i)),
3412 kDimX, b);
3413 auto emit_element = [&] {
3414 return (*emit_elem_function)(thread_id_info, source_idx_x, y_loc,
3415 x_loc, linear_index);
3416 };
3417 if (check_x_tile_bounds) {
3418 ksl.If("x_in_tile", b->CreateICmpULT(x_loc, tile_dimensions[kDimX]),
3419 emit_element);
3420 } else {
3421 emit_element();
3422 }
3423 }
3424 });
3425 }
3426
3427 void IrEmitterUnnested::EmitTile(
3428 const TilingScheme& tiling_scheme, const IrArray::Index& tile_origin_index,
3429 const ThreadIdInfo& thread_id_info, ValueVector3 tile_dimensions,
3430 const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
3431 llvm::Type* index_ty = tile_dimensions[kDimY]->getType();
3432 auto constant = [&](int64_t val) {
3433 return llvm::ConstantInt::get(index_ty, val);
3434 };
3435 llvm::Value* num_threads_y = constant(tiling_scheme.GetNumThreadsFor(kDimY));
3436 llvm::Value* start_offset_x =
3437 GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, index_ty, &b_);
3438
3439 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3440 IrArray::Index source_idx =
3441 tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
3442
3443 ksl.For(
3444 "y_in_tile",
3445 /*start=*/thread_id_info.thread_id_y,
3446 /*end=*/
3447 tile_dimensions[kDimY],
3448 /*step=*/num_threads_y, [&](llvm::Value* y_loc) {
3449 auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
3450 return EmitXTileLoop(thread_id_info, tiling_scheme,
3451 check_x_tile_bounds, start_offset_x, y_loc,
3452 tile_dimensions, source_idx, &b_,
3453 &emit_elem_function);
3454 };
3455
3456 // Only take this path when we unroll in a way vectorizable by LLVM.
3457 // Special-case the tile not fitting completely for an even row size.
3458 // For an odd row size, every other row isn't aligned to the
3459 // vectorized size, so it can't be vectorized by LLVM.
3460 if (tiling_scheme.GetIndexingOrder() == kStridedIndexingX) {
3461 ksl.If(
3462 "is_full_tile",
3463 b_.CreateICmpEQ(
3464 constant(tiling_scheme.GetBlockTileSizeFor(kDimX)),
3465 tile_dimensions[kDimX]),
3466 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
3467 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
3468 } else {
3469 unroll_inner_tile_loop(/*check_x_tile_bounds=*/true);
3470 }
3471 });
3472 }
3473
3474 static IrArray::Index GetUnnormalizedIndex(
3475 const IrArray::Index& normalized_shape_index,
3476 const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
3477 const TilingScheme& tiling_scheme) {
3478 DCHECK_EQ(normalized_shape_index.size(), 3);
3479 // If the normalization only adds new dimensions of size 1, generate
3480 // simpler indexing. LLVM doesn't always simplify the more complicated
3481 // indexing, and this prevents it from vectorizing some cases. We do this
3482 // only for major_to_minor memory layout.
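// For example (illustrative shapes only): a normalized index (0, i, j) into a
// [1, A, B] shape maps directly to (i, j) of an unnormalized f32[A,B]{1,0}
// output, skipping the linearize/delinearize fallback at the bottom.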
3483 if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3484 unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
3485 unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
3486 unnormalized_shape.layout().minor_to_major(1) == 0) {
3487 CHECK_EQ(normalized_shape_index.dims()[0], 1);
3488 auto multidim = normalized_shape_index.multidim();
3489 return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
3490 normalized_shape_index.GetType());
3491 }
3492 if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3493 unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] &&
3494 unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] &&
3495 unnormalized_shape.layout().minor_to_major(1) == 1) {
3496 CHECK_EQ(normalized_shape_index.dims()[0], 1);
3497 auto multidim = normalized_shape_index.multidim();
3498 return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape,
3499 normalized_shape_index.GetType());
3500 }
3501 llvm::Value* linear =
3502 normalized_shape_index.Linearize(tiling_scheme.GetDimsInElems(), b_);
3503 return IrArray::Index(linear, unnormalized_shape, b_);
3504 }
3505
3506 // Emits code to process a tensor element in a tile for the given kLoop fusion
3507 // HLO containing parameters that are a 0-2-1 transpose of its outputs.
3508 //
3509 // index: The index for the first output element in the normalized tensor, that
3510 // is the resulting tensor after collapsing contiguous dimensions that play
3511 // the same role in the transpose.
3512 // kernel_info: Other information to support the kernel code generation.
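// For example (shapes are illustrative): if the output has shape [A, B, C],
// a 0-2-1 transposed parameter has shape [A, C, B]; the two minor dimensions
// are the ones exchanged via the shared-memory tiles in param_shmem_buffers.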
3513 void IrEmitterUnnested::EmitTileElementForTranspose(
3514 const ThreadIdInfo& thread_id_info, mlir::lmhlo::FusionOp fusion,
3515 absl::Span<const llvm_ir::IrArray> operand_arrays,
3516 absl::Span<const llvm_ir::IrArray> output_arrays,
3517 const llvm_ir::IrArray::Index& index, const TilingScheme& tiling_scheme,
3518 llvm::Value* y_loc, llvm::Value* x_loc,
3519 absl::Span<llvm::Value* const> param_shmem_buffers) {
3520 const HloComputation* fused_computation =
3521 *GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
3522 /*is_fusion=*/true);
3523 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3524 GetNestedComputer());
3525 FusedIrEmitter fused_emitter(elem_emitter);
3526 for (int i = 0; i < operand_arrays.size(); i++) {
3527 llvm_ir::ElementGenerator gen;
3528 if (auto* param_tile_buffer =
3529 llvm::cast_or_null<llvm::GlobalVariable>(param_shmem_buffers[i])) {
3530 gen = [this, param_tile_buffer, x_loc, y_loc,
3531 thread_id_info](llvm_ir::IrArray::Index index) {
3532 // TODO(jlebar): Add AA metadata to this load. Tile buffers are
3533 // global variables, so LLVM's points-to analysis doesn't help us
3534 // much. And we want the AA info to be present before address
3535 // spaces are inferred (which is pretty late in the pipeline), so
3536 // even if we had address-space-based AA in LLVM, it wouldn't help
3537 // us much here.
3538 std::vector<llvm::Value*> idx = {x_loc, y_loc};
3539 auto gep =
3540 thread_id_info.GEPIntoSharedMemory(&b_, param_tile_buffer, idx);
3541 auto type =
3542 thread_id_info.GEPIntoSharedMemoryType(param_tile_buffer, idx);
3543 return Load(type, gep, "tiled_buffer");
3544 };
3545 } else {
3546 auto array = operand_arrays[i];
3547 auto name = fused_computation->parameter_instruction(i)->name();
3548 gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
3549 return array.EmitReadArrayElement(index, &b_, name);
3550 };
3551 }
3552 fused_emitter.BindGenerator(*fused_computation->parameter_instruction(i),
3553 std::move(gen));
3554 }
3555 IrArray::Index untiled_index = GetUnnormalizedIndex(
3556 index, output_arrays[0].GetShape(), &b_, tiling_scheme);
3557 llvm_ir::ElementGenerator output_generator =
3558 *fused_emitter.GetGenerator(*fused_computation->root_instruction());
3559 llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
3560 if (output_arrays.size() > 1) {
3561 DCHECK(output_value->getType()->isStructTy());
3562 DCHECK_EQ(output_value->getType()->getStructNumElements(),
3563 output_arrays.size());
3564 for (int64_t i = 0; i < output_arrays.size(); ++i) {
3565 output_arrays[i].EmitWriteArrayElement(
3566 untiled_index, ExtractValue(output_value, i), &b_);
3567 }
3568 } else {
3569 output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
3570 }
3571 }
3572
3573 static int GetNumOutputs(const Shape& shape) {
3574 if (shape.IsTuple()) {
3575 return shape.tuple_shapes_size();
3576 }
3577 return 1;
3578 }
3579
3580 ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState(
3581 mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info,
3582 absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
3583 FusedIrEmitter& fused_emitter) {
3584 ReductionCodegenState reduction_codegen_state(reduction_info);
3585 VLOG(10) << "Emit prologue for reduction: " << MlirToString(fusion);
3586
3587 for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
3588 int num_partial_results = reduction_codegen_state.GetNumPartialResults();
3589 for (int op_result_idx = 0;
3590 op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
3591 Shape result_shape = reduce_hlo->shape().IsTuple()
3592 ? reduce_hlo->shape().tuple_shapes(op_result_idx)
3593 : reduce_hlo->shape();
3594
3595 llvm::Type* element_type =
3596 llvm_ir::PrimitiveTypeToIrType(result_shape.element_type(), module_);
3597 llvm::AllocaInst* reduction_input_address =
3598 llvm_ir::EmitAllocaAtFunctionEntry(element_type,
3599 "reduction_input_address", &b_);
3600
3601 llvm::AllocaInst* partial_result_address =
3602 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
3603 element_type, /*element_count=*/b_.getInt32(num_partial_results),
3604 "partial_reduction_result", &b_);
3605
3606 const HloInstruction* init_value =
3607 reduce_hlo->init_values()[op_result_idx];
3608
3609 // Initialize the partial result with the initial value of the reduction.
3610 llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(*init_value))(
3611 IrArray::Index(b_.getInt32Ty()))
3612 .ValueOrDie();
3613
3614 for (int i = 0; i < num_partial_results; ++i) {
3615 b_.CreateStore(init_ir_value,
3616 InBoundsGEP(partial_result_address->getAllocatedType(),
3617 partial_result_address, {b_.getInt32(i)}));
3618 }
3619
3620 const TilingScheme& tiling_scheme =
3621 reduction_codegen_state.GetTilingScheme();
3622 int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX);
3623 llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* {
3624 if (reduction_codegen_state.IsRowReduction()) {
3625 // Multi-row reductions do not use shared memory.
3626 if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) >
3627 1) {
3628 return nullptr;
3629 }
3630 // Allocate __shared__
3631 // cache[num_partial_results][num_warps][scaling_factor].
3632 CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0);
3633 int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize();
3634 return AllocateShared(tiling_scheme, element_type,
3635 {num_partial_results, num_warps},
3636 "shared_cache");
3637 } else {
3638 // Allocate __shared__
3639 // cache[num_partial_results][num_threads][num_threads + 1], where
3640 // num_threads == num_threads_x == num_threads_y. The "+1" is used to
3641 // avoid bank conflicts.
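// Without the "+1" padding, threads of a warp reading a column of the square
// [num_threads][num_threads] tile would typically map to the same shared
// memory bank (the stride between consecutive rows being a multiple of the
// bank count), serializing the accesses; the extra element staggers rows
// across banks.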
3642 CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY));
3643 return AllocateShared(
3644 tiling_scheme, element_type,
3645 {num_partial_results, num_threads_x, num_threads_x + 1},
3646 "shared_cache");
3647 }
3648 }();
3649
3650 llvm_ir::ElementGenerator input_gen =
3651 *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
3652 reduction_codegen_state.SetCalculationStateFor(
3653 {shared_cache, init_ir_value, partial_result_address,
3654 reduction_input_address, input_gen},
3655 reduce_hlo, op_result_idx);
3656 }
3657 }
3658
3659 return reduction_codegen_state;
3660 }
3661
3662 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
3663 const HloComputation* reducer,
3664 absl::Span<std::pair<llvm::Value* const, llvm::Type* const>>
3665 partial_result_addresses,
3666 int threads_per_block, int num_results_per_warp) {
3667 // This only works when the block size is a multiple of 32 threads.
3668
3669 // We check this here as a mistake in the number of threads per
3670 // block is very hard to detect.
3671 CHECK_EQ(threads_per_block % 32, 0);
3672 CHECK_EQ(WarpSize() % num_results_per_warp, 0);
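// The IR emitted below corresponds to the usual warp-level shuffle reduction;
// roughly, in CUDA-like pseudocode (assuming a single result per warp):
//
//   for (int distance = 16; distance >= 1; distance /= 2) {
//     partial = reducer(partial, __shfl_down_sync(0xffffffffu, partial,
//                                                 distance));
//   }
//
// With num_results_per_warp > 1 the starting distance shrinks so that the
// independent results packed into one warp are not mixed together.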
3673
3674 for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
3675 absl::InlinedVector<llvm::Value*, 2> reduction_params;
3676
3677 for (auto acc : partial_result_addresses) {
3678 reduction_params.push_back(acc.first);
3679 }
3680
3681 for (auto i : partial_result_addresses) {
3682 llvm::Value* partial_result_address = i.first;
3683 llvm::Type* element_type = i.second;
3684
3685 int bit_width = llvm_ir::GetSizeInBits(element_type);
3686 llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
3687 element_type, "result_from_other_lane", &b_);
3688
3689 reduction_params.push_back(result_from_other_lane);
3690
3691 // Bitcast cannot be applied to aggregate types (even packed ones), so
3692 // we bitcast addresses of load/store to intN* of the same bit-width.
3693 llvm::Type* shuffled_value_type =
3694 element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
3695 auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
3696 return b_.CreatePointerBitCastOrAddrSpaceCast(
3697 ptr, shuffled_value_type->getPointerTo());
3698 };
3699
3700 llvm::Value* partial_result =
3701 b_.CreateLoad(shuffled_value_type,
3702 convert_pointer_for_shuffle(partial_result_address),
3703 "partial_reduction_result");
3704 b_.CreateStore(
3705 EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
3706 convert_pointer_for_shuffle(result_from_other_lane));
3707 }
3708
3709 StatusOr<std::vector<llvm::Value*>> returned_scalars =
3710 ComputeNestedElementFromAddrs(*reducer, reduction_params);
3711 TF_CHECK_OK(returned_scalars.status());
3712
3713 for (int i = 0; i < returned_scalars->size(); i++) {
3714 b_.CreateStore(/*Val=*/returned_scalars->at(i),
3715 /*Ptr=*/partial_result_addresses[i].first);
3716 }
3717 }
3718 }
3719
3720 llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction(
3721 int partial_result_idx, llvm::Type* index_ty,
3722 const ReductionCodegenState& reduction_codegen_state,
3723 const TilingKernelInfo& tiling_kernel_info,
3724 const IrEmitterUnnested::ReductionOutputMap& output_arrays,
3725 const HloReduceInstruction* reduction, int output_idx) {
3726 auto constant = [&](uint64_t c) -> llvm::Constant* {
3727 return llvm::ConstantInt::get(index_ty, c);
3728 };
3729
3730 const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3731 const ThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info;
3732
3733 IrArray::Index start_offset = [&] {
3734 llvm::Value* x_loc = thread_id_info.thread_id_x;
3735 llvm::Value* y_loc = thread_id_info.thread_id_y;
3736 if (!reduction_codegen_state.IsRowReduction()) {
3737 std::swap(x_loc, y_loc);
3738 }
3739 llvm::Value* start_offset_x =
3740 GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_);
3741 return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
3742 .AddOffsetToDim(start_offset_x, kDimX, &b_);
3743 }();
3744
3745 const IrArray& output_array = output_arrays.at(reduction)[output_idx];
3746 const Shape& operand_shape = reduction->inputs()[output_idx]->shape();
3747 Shape reduction_kept_element_shape =
3748 ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape);
3749
3750 // Given the IrArray index of a reduction input, returns the linear address of
3751 // the reduction output as if the reduction were going to keep the input shape
3752 // with the dimensions being reduced moved.
3753 llvm::Value* untransposed_output_linear_address = [&] {
3754 const llvm_ir::IrArray::Index index =
3755 start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, &b_);
3756 if (reduction_codegen_state.IsRowReduction()) {
3757 // For row-reduction, y-coordinate determines which row we write into.
3758 return index[kDimY];
3759 }
3760 // For column reduction, we get the transposed address.
3761 absl::Span<const int64_t> dims_in_elem = tiling_scheme.GetDimsInElems();
3762 llvm::Value* x_dim_size =
3763 index.GetConstantWithIndexType(dims_in_elem[kDimX]);
3764 llvm::Value* x_block_offset = b_.CreateMul(index[kDimZ], x_dim_size);
3765 return b_.CreateAdd(x_block_offset, index[kDimX]);
3766 }();
3767
3768 // A reduction is allowed to transpose its output. For example, suppose
3769 // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are
3770 // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
3771 // f32[10,30]{0,1} (transposing the two output dims).
3772 //
3773 // At this point in the function we have a "partial sum" of input elements
3774 // (stored in partial_result_addresses), and we need to accumulate it into
3775 // the correct output element.
3776 IrArray::Index element_index(
3777 /*linear=*/untransposed_output_linear_address,
3778 reduction_kept_element_shape, &b_);
3779 IrArray::Index output_index(element_index.multidim(), output_array.GetShape(),
3780 element_index.GetType());
3781
3782 return output_array.EmitArrayElementAddress(output_index, &b_,
3783 "output_element_address");
3784 }
3785
3786 llvm::Value* IrEmitterUnnested::EmitBlockId(int32_t num_blocks,
3787 llvm::Type* index_ty) {
3788 llvm::Value* block_id = gpu::EmitCallToTargetIntrinsic(
3789 gpu::TargetIntrinsicID::kBlockIdx, {}, {}, &b_);
3790 if (num_blocks != 0) {
3791 llvm_ir::AddRangeMetadata(0, num_blocks,
3792 llvm::cast<llvm::Instruction>(block_id));
3793 }
3794 llvm::Value* linear_block_id =
3795 b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
3796 return linear_block_id;
3797 }
3798
3799 void IrEmitterUnnested::EmitPrintfWithThreadId(
3800 absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
3801 std::optional<int64_t> thread_id_filter,
3802 std::optional<int64_t> block_id_filter) {
3803 llvm::Value* thread_id = EmitThreadId(
3804 /*threads_per_block=*/1024, b_.getInt32Ty());
3805 llvm::Value* block_id = EmitBlockId(0, b_.getInt32Ty());
3806 std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
3807 updated_arguments.insert(updated_arguments.end(), arguments.begin(),
3808 arguments.end());
3809 llvm::Value* constraint = b_.getTrue();
3810 if (thread_id_filter) {
3811 constraint = b_.CreateAnd(
3812 constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
3813 }
3814 if (block_id_filter) {
3815 constraint = b_.CreateAnd(
3816 constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
3817 }
3818 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3819 ksl.If(constraint, [&] {
3820 xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
3821 updated_arguments, &b_);
3822 });
3823 }
3824
3825 llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input,
3826 llvm::Type* element_type,
3827 llvm::Twine name) {
3828 return b_.CreateAddrSpaceCast(input,
3829 llvm::PointerType::get(element_type,
3830 /*AddressSpace=*/0),
3831 name);
3832 }
3833
3834 void IrEmitterUnnested::EmitReductionOutputForRowReduction(
3835 const TilingKernelInfo& tiling_kernel_info,
3836 const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
3837 const ReductionOutputMap& output_arrays,
3838 const HloReduceInstruction* reduction, int partial_result_idx) {
3839 const HloComputation* reducer = reduction->to_apply();
3840 const auto& thread_id_info = tiling_kernel_info.thread_id_info;
3841 auto constant = [&](uint64_t c) -> llvm::Constant* {
3842 return llvm::ConstantInt::get(index_ty, c);
3843 };
3844 auto is_zero = [&](llvm::Value* value) {
3845 return b_.CreateICmpEQ(value, constant(0));
3846 };
3847
3848 int num_outputs = reducer->num_parameters() / 2;
3849 const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3850 absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 2>
3851 current_outputs;
3852 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
3853 const ReductionCodegenState::ReductionCalculationState& state =
3854 reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
3855 current_outputs.push_back(
3856 {InBoundsGEP(state.partial_result_address->getAllocatedType(),
3857 state.partial_result_address,
3858 {constant(partial_result_idx)}, "current_output"),
3859 state.partial_result_address->getAllocatedType()});
3860 }
3861
3862 int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2];
3863 int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size);
3864 EmitFullWarpShuffleDownLoopForReduce(
3865 reducer, absl::MakeSpan(current_outputs),
3866 tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp);
3867
3868 KernelSupportLibrary ksl(&b_);
3869 llvm::Value* warp_id =
3870 b_.CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize()));
3871
3872 auto emit_write_output =
3873 [&](llvm::Value* write_condition,
3874 const absl::InlinedVector<
3875 std::pair<llvm::Value* const, llvm::Type* const>, 2>& values) {
3876 ksl.If("reduction_write_output", write_condition, [&] {
3877 for (int oidx = 0; oidx < num_outputs; oidx++) {
3878 llvm::Value* output_address = GetOutputAddressForReduction(
3879 partial_result_idx, index_ty, reduction_codegen_state,
3880 tiling_kernel_info, output_arrays, reduction, oidx);
3881
3882 if (reduction_codegen_state.IsRaceFree()) {
3883 Store(Load(values[oidx].second, values[oidx].first, "output"),
3884 output_address);
3885 } else {
3886 CHECK_EQ(num_outputs, 1);
3887 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
3888 *reducer, output_address, values[oidx].first,
3889 values[oidx].second));
3890 }
3891 }
3892 });
3893 };
3894
3895 if (num_rows_per_warp > 1) {
3896 llvm::Value* is_writing_thread = is_zero(b_.CreateAnd(
3897 thread_id_info.thread_id_x, constant(reduced_dimension_size - 1)));
3898 emit_write_output(is_writing_thread, current_outputs);
3899 return;
3900 }
3901
3902 ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
3903 for (int oidx = 0; oidx < num_outputs; oidx++) {
3904 const ReductionCodegenState::ReductionCalculationState& state =
3905 reduction_codegen_state.GetCalculationStateFor(reduction, oidx);
3906 llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory(
3907 &b_, state.shared_cache, {constant(partial_result_idx), warp_id});
3908 Store(Load(current_outputs[oidx].second, current_outputs[oidx].first),
3909 shmem_output_addr);
3910 }
3911 });
3912
3913 // TODO(cheshire): Don't we want to sync it once for everything in the
3914 // output? Not once per each?
3915 EmitSyncThreads();
3916 ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
3917 absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 2>
3918 selected_values;
3919 for (int oidx = 0; oidx < num_outputs; oidx++) {
3920 const ReductionCodegenState::ReductionCalculationState& state =
3921 reduction_codegen_state.GetCalculationStateFor(reduction, oidx);
3922 llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory(
3923 &b_, state.shared_cache,
3924 {constant(partial_result_idx), thread_id_info.lane_id});
3925
3926 llvm::Type* element_type =
3927 state.partial_result_address->getAllocatedType();
3928
3929 /* Ensure the initial value address is in the generic address space, not scratch. */
3930 llvm::Value* initial_value_addr =
3931 CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry(
3932 element_type, "initial_value_addr", &b_),
3933 element_type);
3934 b_.CreateStore(state.initial_value, initial_value_addr);
3935
3936 llvm::Value* warp_exists = b_.CreateICmpULT(
3937 thread_id_info.thread_id_x,
3938 constant(tiling_scheme.GetNumThreadsFor(kDimX) / WarpSize()));
3939
3940 llvm::Value* selected_value =
3941 b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
3942
3943 selected_values.push_back({selected_value, element_type});
3944 }
3945
3946 // If only one warp is present in the block, then we don't need inter-warp
3947 // reduction.
3948 // TODO(b/241414088) If only one warp is present, then the inter-warp
3949 // communication through shared memory and the barrier synchronization are
3950 // also unnecessary and should be removed.
3951 if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) {
3952 EmitFullWarpShuffleDownLoopForReduce(
3953 reducer, absl::MakeSpan(selected_values),
3954 tiling_scheme.GetNumThreadsPerBlock());
3955 }
3956
3957 emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values);
3958 });
3959 }
3960
3961 void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
3962 const TilingKernelInfo& tiling_kernel_info,
3963 const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
3964 const ReductionOutputMap& output_arrays,
3965 const HloReduceInstruction* reduction, int partial_result_idx) {
3966 KernelSupportLibrary ksl(&b_);
3967 const HloComputation* reducer = reduction->to_apply();
3968 const auto& thread_id_info = tiling_kernel_info.thread_id_info;
3969
3970 auto constant = [&](uint64_t c) -> llvm::Constant* {
3971 return llvm::ConstantInt::get(index_ty, c);
3972 };
3973 auto is_zero = [&](llvm::Value* value) {
3974 return b_.CreateICmpEQ(value, constant(0));
3975 };
3976 const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3977 int num_outputs = reducer->num_parameters() / 2;
3978
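// Epilogue for a column reduction: partial results are staged in shared
// memory at [x][y] and read back at [y][x], so that the partials that
// contribute to one output element land in the lanes of a single warp. They
// are then combined with a warp shuffle-down and written out by lane 0 of the
// threads that fall within the output bounds.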
3979 // Store the transpose in shared memory.
3980 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
3981 const ReductionCodegenState::ReductionCalculationState& state =
3982 reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
3983 llvm::GlobalVariable* shared_cache = state.shared_cache;
3984 llvm::AddrSpaceCastInst* shmem_output_addr =
3985 llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
3986 &b_, shared_cache,
3987 {constant(partial_result_idx), thread_id_info.thread_id_x,
3988 thread_id_info.thread_id_y},
3989 "shmem_output_address"));
3990 llvm::Value* current_output =
3991 InBoundsGEP(state.partial_result_address->getAllocatedType(),
3992 state.partial_result_address,
3993 {constant(partial_result_idx)}, "current_output");
3994
3995 llvm::Value* current_output_value =
3996 Load(state.partial_result_address->getAllocatedType(), current_output);
3997 b_.CreateStore(current_output_value, shmem_output_addr);
3998 }
3999
4000 EmitSyncThreads();
4001
4002 // Get transposed element from shared memory.
4003 absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 1>
4004 shmem_transposed_addrs;
4005
4006 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
4007 const ReductionCodegenState::ReductionCalculationState& state =
4008 reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
4009 llvm::AddrSpaceCastInst* shmem_transposed_addr =
4010 llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
4011 &b_, state.shared_cache,
4012 {constant(partial_result_idx), thread_id_info.thread_id_y,
4013 thread_id_info.thread_id_x},
4014 "shmem_transposed_addr"));
4015 shmem_transposed_addrs.push_back(
4016 {shmem_transposed_addr, llvm::cast<llvm::GetElementPtrInst>(
4017 shmem_transposed_addr->getPointerOperand())
4018 ->getResultElementType()});
4019 }
4020
4021 EmitFullWarpShuffleDownLoopForReduce(reducer,
4022 absl::MakeSpan(shmem_transposed_addrs),
4023 tiling_scheme.GetNumThreadsPerBlock());
4024
4025 // Some warps in the block are completely outside of the bound of the
4026 // tensor, so they should not write any output at all.
4027 llvm::Value* has_output = b_.CreateAnd(
4028 b_.CreateICmpULT(
4029 GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty,
4030 &b_),
4031 tiling_kernel_info.output_tile_bounds[kDimX]),
4032 b_.CreateICmpULT(thread_id_info.thread_id_x,
4033 tiling_kernel_info.output_tile_bounds[kDimY]));
4034
4035 ksl.If("reduction_write_output",
4036 b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4037 for (int oidx = 0; oidx < num_outputs; oidx++) {
4038 llvm::Value* output_address = GetOutputAddressForReduction(
4039 partial_result_idx, index_ty, reduction_codegen_state,
4040 tiling_kernel_info, output_arrays, reduction, oidx);
4041 if (reduction_codegen_state.IsRaceFree()) {
4042 Store(Load(shmem_transposed_addrs[oidx].second,
4043 shmem_transposed_addrs[oidx].first, "output_value"),
4044 output_address);
4045 } else {
4046 CHECK_EQ(num_outputs, 1);
4047 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4048 *reducer, output_address, shmem_transposed_addrs[oidx].first,
4049 shmem_transposed_addrs[oidx].second));
4050 }
4051 }
4052 });
4053 }
4054
4055 llvm::Value* IrEmitterUnnested::EmitThreadId(int64_t threads_per_block,
4056 llvm::Type* index_ty) {
4057 // Emit the intrinsic call that returns the raw thread id within the block
4058 // and annotate it with its known range [0, threads_per_block).
4059 llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4060 gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4061 llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4062 return b_.CreateIntCast(thread_id_raw, index_ty,
4063 /*isSigned=*/true, "thread.id.x");
4064 }
4065
4066 StatusOr<IrEmitterUnnested::ThreadIdInfo> IrEmitterUnnested::EmitThreadIdInfo(
4067 const TilingScheme& tiling_scheme, llvm::Type* index_ty) {
4068 auto constant = [&](uint64_t c) -> llvm::Constant* {
4069 return llvm::ConstantInt::get(index_ty, c);
4070 };
4071 llvm::Value* thread_id_physical =
4072 EmitThreadId(tiling_scheme.GetNumThreadsPerBlockPhysical(), index_ty);
4073 int64_t num_blocks = tiling_scheme.GetNumberOfBlocksPhysical();
4074 if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
4075 return FailedPrecondition(
4076 "Number of physical blocks (%d) does not fit in an i32 in tiling "
4077 "scheme: %s",
4078 num_blocks, tiling_scheme.ToString());
4079 }
4080 llvm::Value* block_id_physical = EmitBlockId(num_blocks, index_ty);
4081
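// Map the physical thread onto the logical tiling: the logical thread id is
// thread_id_physical % num_threads_per_block, and the logical block id is
// block_id_physical * scaling_factor + thread_id_physical /
// num_threads_per_block, so a physical block covers up to scaling_factor
// logical blocks.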
4082 // Wait this will break coalescing.
4083 llvm::Value* thread_id_logical = b_.CreateURem(
4084 thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock()));
4085 llvm::Value* scaling = b_.CreateUDiv(
4086 thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock()));
4087 llvm::Value* block_id_logical = b_.CreateAdd(
4088 b_.CreateMul(block_id_physical,
4089 constant(tiling_scheme.GetThreadIdScalingFactor())),
4090 scaling);
4091
4092 llvm::Value* num_threads_x_v =
4093 constant(tiling_scheme.GetNumThreadsFor(kDimX));
4094
4095 llvm::Value* block_exists = b_.CreateICmpULT(
4096 block_id_logical, constant(tiling_scheme.GetNumberOfBlocks()));
4097 llvm_ir::EmitEarlyReturn(block_exists, &b_);
4098 return {{thread_id_logical,
4099 /*thread_id_x=*/
4100 b_.CreateURem(thread_id_logical, num_threads_x_v, "thread_id.x"),
4101 /*thread_id_y=*/
4102 b_.CreateUDiv(thread_id_logical, num_threads_x_v, "thread_id.y"),
4103 /*lane_id=*/
4104 b_.CreateURem(thread_id_logical, constant(WarpSize()), "lane_id"),
4105 /*block_id=*/block_id_logical,
4106 /*scaling=*/scaling}};
4107 }
4108
4109 StatusOr<IrEmitterUnnested::TilingKernelInfo>
4110 IrEmitterUnnested::EmitTilingKernel(
4111 const TilingScheme& tiling_scheme, llvm::Type* index_ty,
4112 const TileElementGenerator& tile_element_generator) {
4113 absl::Span<const int64_t> dims_in_elems = tiling_scheme.GetDimsInElems();
4114 Vector3 dims_in_blocks = tiling_scheme.GetDimsInBlocks();
4115 auto constant = [&](uint64_t c) -> llvm::Constant* {
4116 return llvm::ConstantInt::get(index_ty, c);
4117 };
4118
4119 TF_ASSIGN_OR_RETURN(ThreadIdInfo thread_id_info,
4120 EmitThreadIdInfo(tiling_scheme, index_ty));
4121
4122 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4123
4124 const IrArray::Index block_coords = [&] {
4125 IrArray::Index starting_block(thread_id_info.block_id,
4126 ShapeUtil::MakeShapeWithDescendingLayout(
4127 PRED /*arbitrary*/, dims_in_blocks),
4128 &b_);
4129 std::vector<llvm::Value*> multidim = {
4130 b_.CreateMul(starting_block[0],
4131 constant(tiling_scheme.GetBlockTileSizeFor(0)),
4132 "block_origin.z"),
4133 starting_block[1], starting_block[2]};
4134 return IrArray::Index(multidim, dims_in_blocks, index_ty);
4135 }();
4136
4137 ValueVector3 tile_dimensions;
4138 for (int i = kDimY; i < kDimTot; ++i) {
4139 int64_t tile_size_for_dim = tiling_scheme.GetBlockTileSizeFor(i);
4140 // Only the last row or column may not have full size.
4141 llvm::Value* is_last =
4142 b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4143 int64_t partial_row =
4144 dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4145 tile_dimensions[i] =
4146 b_.CreateSelect(is_last, constant(partial_row),
4147 constant(tile_size_for_dim), "tile_bound");
4148 }
4149
4150 IrArray::Index tile_origin = [&] {
4151 std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4152 llvm::Type* index_ty = block_coords.GetType();
4153 for (int i = kDimY; i < kDimTot; ++i) {
4154 elem_multi_index[i] =
4155 b_.CreateMul(block_coords[i],
4156 llvm::ConstantInt::get(
4157 index_ty, tiling_scheme.GetBlockTileSizeFor(i)),
4158 "tile_origin." + std::to_string(i));
4159 }
4160 return IrArray::Index(elem_multi_index, tiling_scheme.GetDimsInElems(),
4161 index_ty);
4162 }();
4163
4164 auto emit_tile = [&](const IrArray::Index& tile) {
4165 tile_element_generator(thread_id_info, tile, tile_dimensions);
4166 };
4167
4168 if (tiling_scheme.GetBlockTileSizeFor(kDimZ) == 1) {
4169 emit_tile(tile_origin);
4170 } else {
4171 llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4172 llvm::Value* block_size_for_dim =
4173 constant(tiling_scheme.GetBlockTileSizeFor(kDimZ));
4174 llvm::Value* block_id_for_dim =
4175 b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4176 llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4177 llvm::Value* last_block_size_for_dim = constant(
4178 dims_in_elems[kDimZ] -
4179 (dims_in_blocks[kDimZ] - 1) * tiling_scheme.GetBlockTileSizeFor(kDimZ));
4180
4181 llvm::Value* num_tiles_in_block =
4182 b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4183 last_block_size_for_dim, block_size_for_dim);
4184 ksl.For("loop_z",
4185 /*start=*/constant(0),
4186 /*end=*/num_tiles_in_block,
4187 /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4188 IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4189 block_dim_induction_var, kDimZ, &b_);
4190 emit_tile(tile_index);
4191 });
4192 }
4193
4194 return {{tile_dimensions, tile_origin, thread_id_info}};
4195 }
4196
4197 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4198 return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4199 }
4200
4201 Status IrEmitterUnnested::EmitHlo021Tile(
4202 mlir::lmhlo::FusionOp fusion,
4203 absl::Span<const llvm_ir::IrArray> operand_arrays,
4204 absl::Span<const llvm_ir::IrArray> output_arrays,
4205 TransposeDimsAndParams descr, const TilingScheme& tiling_scheme,
4206 const LaunchDimensions& launch_dimensions) {
4207 std::string name = GetIrNameFromLoc(fusion.getLoc());
4208
4209 llvm::Type* index_type = GetIndexTypeForKernel(
4210 fusion.getOperation(), launch_dimensions.launch_bound(), &b_);
4211
4212 // For each tiled parameter, cast its input IrArray to the corresponding
4213 // reduced shape and keep the reduced shape live during IR emission.
4214 std::map<int64_t, IrArray> param_in_reduced_shape_arrays;
4215 std::vector<llvm::Value*> param_shmem_buffers(fusion.getInputBuffers().size(),
4216 nullptr);
4217
4218 auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
4219 absl::string_view buffer_name) {
4220 // For Nvidia GPUs, the warp size is 32 threads and shared memory is
4221 // organized into 32 banks. We usually use the warp size, or a multiple of
4222 // the warp size, as the tile size. This can cause all elements in the same
4223 // column of a tile to map to the same memory bank and therefore produce
4224 // shared memory bank conflicts. Adding 1 to the minor dimension of the
4225 // shared memory buffer reduces such bank conflicts.
4226 return AllocateShared(tiling_scheme, elem_ty,
4227 {tiling_scheme.GetBlockTileSizeFor(kDimY),
4228 tiling_scheme.GetBlockTileSizeFor(kDimX) + 1},
4229 buffer_name);
4230 };
4231
4232 for (int64_t id : descr.params) {
4233 const Shape& param_shape = GetShape(fusion.getInputBuffers()[id]);
4234 param_shmem_buffers[id] = get_shared_memory_buffer(
4235 llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
4236 IrName(name, StrCat("tile", id)));
4237 VLOG(3) << "Added shmem buffer for parameter " << id << ": "
4238 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
4239 Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4240 param_shape.element_type(), Permute(descr.dims, {0, 2, 1}));
4241 param_in_reduced_shape_arrays[id] =
4242 operand_arrays[id].CastToShape(reduced_shape, &b_);
4243 }
4244
4245 TileElementGenerator tile_generator =
4246 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4247 ValueVector3 tile_dimensions) {
4248 // If shared memory transpose is needed, wait for all threads to reach
4249 // this point, lest we copy a value from tile to output before the other
4250 // thread copies it from input to tile. This is `__syncthreads` in CUDA.
4251 if (!descr.params.empty()) {
4252 // Calculate the input tile origin from the output tile origin.
4253 const IrArray::Index input_tile_origin(
4254 Permute(index.multidim(), {0, 2, 1}),
4255 Permute(index.dims(), {0, 2, 1}), index.GetType());
4256
4257 ValueVector3 transposed_tile_dimensions = {tile_dimensions[kDimZ],
4258 tile_dimensions[kDimX],
4259 tile_dimensions[kDimY]};
4260
4261 // Copy input parameter values to shared memory buffers:
4262 // tile[thread_id_y, thread_id_x] = input[index]
4263 // Note that tile_width and tile_height are flipped here because we
4264 // are reading a transposed tile.
4265 EmitTile(tiling_scheme, input_tile_origin, thread_id_info,
4266 transposed_tile_dimensions,
4267 [&](const ThreadIdInfo& thread_id_info,
4268 const IrArray::Index& index, llvm::Value* y_loc,
4269 llvm::Value* x_loc, llvm::Value* /*x_iter_num*/) {
4270 for (int64_t id : descr.params) {
4271 IrArray& input_in_logical_shape =
4272 param_in_reduced_shape_arrays[id];
4273
4274 auto shmem_buffer = llvm::cast<llvm::GlobalVariable>(
4275 param_shmem_buffers.at(id));
4276 llvm::Value* value =
4277 input_in_logical_shape.EmitReadArrayElement(
4278 index, &b_, "input_element");
4279
4280 llvm::Value* addr = thread_id_info.GEPIntoSharedMemory(
4281 &b_, shmem_buffer, {y_loc, x_loc});
4282 b_.CreateStore(value, addr);
4283 }
4284 });
4285
4286 // Wait for all threads to reach this point using `__syncthreads` in
4287 // CUDA.
4288 EmitSyncThreads();
4289 }
4290
4291 EmitTile(tiling_scheme, index, thread_id_info, tile_dimensions,
4292 /*emit_elem_function=*/
4293 [&](const ThreadIdInfo& thread_id_info,
4294 const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4295 llvm::Value* x_loc, llvm::Value* /*x_iter_num*/) {
4296 EmitTileElementForTranspose(
4297 thread_id_info, fusion, operand_arrays, output_arrays,
4298 index, tiling_scheme, y_loc, x_loc, param_shmem_buffers);
4299 });
4300 bool block_contains_multi_tiles =
4301 tiling_scheme.GetBlockTileSizeFor(kDimZ) > 1;
4302
4303 // If a tile block contains multiple tiles and shared memory buffers are
4304 // used, we need to wait for all threads to finish using the shared
4305 // memory buffer for the current tile before we move on to process the
4306 // next tile and overwrite the shared memory buffers.
4307 if (block_contains_multi_tiles && !descr.params.empty()) {
4308 EmitSyncThreads();
4309 }
4310 };
4311
4312 return EmitTilingKernel(tiling_scheme, index_type, tile_generator).status();
4313 }
4314
4315 Status IrEmitterUnnested::Emit021Transpose(TransposeDimsAndParams descr,
4316 mlir::lmhlo::FusionOp fusion) {
4317 constexpr int kNumRows = 4;
4318 CHECK_EQ(WarpSize() % kNumRows, 0);
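// A block of kNumRows x WarpSize() threads cooperatively transposes one tile
// through shared memory, so each thread handles several rows of the tile.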
4319 TilingScheme tiling_scheme(descr.dims,
4320 /*tile_sizes=*/{1, WarpSize() / kNumRows, 1},
4321 /*num_threads=*/{1, kNumRows, WarpSize()},
4322 /*indexing_order=*/kLinearIndexingX,
4323 /*vector_size=*/1,
4324 /*scaling_factor=*/1);
4325 LaunchDimensions launch_dimensions(
4326 tiling_scheme.GetNumberOfBlocksPhysical(),
4327 tiling_scheme.GetNumThreadsPerBlockPhysical());
4328 std::vector<llvm_ir::IrArray> ir_arrays;
4329 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
4330 BuildKernelThunk(fusion, GetThunkInfo(fusion), &ir_arrays,
4331 launch_dimensions));
4332 TF_RETURN_IF_ERROR(EmitHlo021Tile(
4333 fusion,
4334 absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size()),
4335 absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size()), descr,
4336 tiling_scheme, launch_dimensions));
4337 AddThunkToThunkSequence(std::move(kernel_thunk));
4338 return Status::OK();
4339 }
4340
4341 namespace {
4342
4343 // Returns true if all the transitive users of hlo before hitting users in
4344 // use_chain_endings are elementwise operations.
4345 bool AreUsersElementwise(
4346 mlir::Value value,
4347 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4348 return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
4349 mlir::Operation* user = use.getOwner();
4350 return use_chain_endings.count(user) ||
4351 (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
4352 absl::c_all_of(user->getResults(),
4353 [&](const mlir::OpResult result) {
4354 return AreUsersElementwise(result,
4355 use_chain_endings);
4356 })
4357
4358 );
4359 });
4360 }
4361
4362 // Returns the number of fusion inputs that have the same dimensions as the
4363 // given shape and are involved in only elementwise operations.
4364 int64_t NumInputsInvolveInOnlyElementwiseOps(
4365 mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
4366 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4367 return absl::c_count_if(
4368 fusion.getFusionParameters(), [&](mlir::Value parameter) {
4369 Shape parameter_shape = GetShape(parameter);
4370 return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
4371 AreUsersElementwise(parameter, use_chain_endings);
4372 });
4373 }
4374
4375 // Returns the number of fusion inputs that have more elements than the given
4376 // shape.
4377 int64_t NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
4378 const Shape& shape) {
4379 int64_t num_elements = ShapeUtil::ElementsIn(shape);
4380 return absl::c_count_if(
4381 fusion.getFusionParameters(), [&](mlir::Value parameter) {
4382 Shape parameter_shape = GetShape(parameter);
4383 return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
4384 });
4385 }
4386
4387 // The benefit of unrolling a kInput fusion that is a column reduction comes
4388 // from the vectorization of non-reduction fusion outputs and fusion inputs.
4389 // On the other hand, unrolling can also introduce factors that can cause
4390 // the kernel to run slower. This routine uses a simple heuristic to estimate
4391 // the benefit as well as the overhead of unrolling in order to decide whether
4392 // unrolling is beneficial for the given kInput fusion.
4393 bool IsUnrollingColumnReductionBeneficial(mlir::lmhlo::FusionOp fusion,
4394 const Shape& input_shape,
4395 int64_t num_kept_minor) {
4396 if (num_kept_minor % (WarpSize() * 2) != 0) {
4397 return false;
4398 }
4399
4400 if (input_shape.dimensions()[input_shape.rank() - 1] < 64) {
4401 return false;
4402 }
4403
4404 int64_t can_be_vectorized = 0;
4405 int64_t cannot_be_vectorized = 0;
4406 llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots();
4407 absl::flat_hash_set<mlir::Operation*> use_chain_endings;
4408 if (fusion_roots.size() == 1) {
4409 if (IsReductionFromOrToContiguousDimensions(fusion_roots[0])) {
4410 use_chain_endings.insert(fusion_roots[0]);
4411 // Atomic.add of the reduction result can't be vectorized.
4412 cannot_be_vectorized++;
4413 }
4414 } else {
4415 for (mlir::Operation* op : fusion_roots) {
4416 if (IsReductionFromOrToContiguousDimensions(op)) {
4417 // Atomic.add of the reduction result can't be vectorized.
4418 cannot_be_vectorized++;
4419 } else {
4420 // Write of the non-reduction result can be vectorized.
4421 can_be_vectorized++;
4422 }
4423 use_chain_endings.insert(op);
4424 }
4425 }
4426 // Fusion inputs that have the same dimensions as the reduce input and are
4427 // involved in only elementwise operations can be vectorized.
4428 can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
4429 use_chain_endings);
4430 // Fusion inputs with more elements than the reduce op input must participate
4431 // in non-elementwise operations and we assume that they are not vectorizable
4432 // for the purpose of estimating the benefit of unrolling. If the kernel is
4433 // unrolled even with such an assumption, and the accesses to those inputs
4434 // turn out to be vectorizable, the compiler will still vectorize them.
4435 cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
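// Unroll only when at least as many accesses benefit from vectorization as
// would remain unvectorized.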
4436 return can_be_vectorized >= cannot_be_vectorized;
4437 }
4438
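// Returns the power of two closest to `v`, rounding ties down; for example,
// NearestPowerOfTwo(5) == 4 and NearestPowerOfTwo(7) == 8.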
4439 int64_t NearestPowerOfTwo(int64_t v) {
4440 if (v < 0) {
4441 return 0;
4442 }
4443 int64_t upper = absl::bit_ceil<uint64_t>(v);
4444 int64_t lower = upper >> 1;
4445 return upper - v < v - lower ? upper : lower;
4446 }
4447
4448 } // namespace
4449
4450 // Returns the primitive bitwidth of the element type of the value's shape.
4451 static int GetPrimitiveBitwidth(mlir::Value i) {
4452 // TODO(timshen): may not be efficient.
4453 return primitive_util::BitWidth(GetShape(i).element_type());
4454 }
4455
4456 // Experimentally determined values to achieve an optimal number of
4457 // bytes-in-flight. Because the number of warps per SM that can be
4458 // concurrently scheduled is bounded, it can be hard to reach an optimal
4459 // number of bytes-in-flight for small reduced dimensions. To address this,
4460 // we increase the number of threads per block (physically, while keeping
4461 // the logical mapping the same), which allows more bytes-in-flight.
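// The extra physical threads created by this factor are mapped by
// EmitThreadIdInfo onto additional logical blocks.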
4462 static int CalculateVirtualThreadScalingFactorForReduction(
4463 const ReductionDimensions& reduction_dimensions,
4464 const se::CudaComputeCapability& cc) {
4465 int dimx = reduction_dimensions.dimensions[kDimX];
4466 if (reduction_dimensions.is_row_reduction && dimx <= 128) {
4467 int rows_per_warp = RowReductionGetRowsPerWarp(dimx);
4468 if (cc.IsAtLeast(se::CudaComputeCapability::AMPERE)) {
4469 return rows_per_warp * 3;
4470 }
4471 return rows_per_warp * 5;
4472 }
4473 return 1;
4474 }
4475
4476 llvm::Type* IrEmitterUnnested::ThreadIdInfo::GEPIntoSharedMemoryType(
4477 llvm::GlobalVariable* shared,
4478 absl::Span<llvm::Value* const> idx_major_to_minor) const {
4479 std::vector<llvm::Value*> idxs_scaled;
4480 idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0));
4481 idxs_scaled.push_back(scaling);
4482 idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(),
4483 idx_major_to_minor.end());
4484 return llvm::GetElementPtrInst::getIndexedType(shared->getValueType(),
4485 idxs_scaled);
4486 }
4487
4488 llvm::Value* IrEmitterUnnested::ThreadIdInfo::GEPIntoSharedMemory(
4489 llvm::IRBuilder<>* b, llvm::GlobalVariable* shared,
4490 absl::Span<llvm::Value* const> idx_major_to_minor,
4491 const llvm::Twine& name) const {
4492 std::vector<llvm::Value*> idxs_scaled;
4493 idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0));
4494 idxs_scaled.push_back(scaling);
4495 idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(),
4496 idx_major_to_minor.end());
4497 llvm::Value* gep =
4498 b->CreateInBoundsGEP(shared->getValueType(), shared, idxs_scaled, name);
4499
4500 llvm::PointerType* pointer_in_addressspace =
4501 llvm::PointerType::getWithSamePointeeType(
4502 llvm::cast<llvm::PointerType>(gep->getType()), /*AddressSpace=*/0);
4503
4504 // __shared__ memory uses a different address space, so we cast it to
4505 // global address space before writing or reading.
4506 return b->CreateAddrSpaceCast(gep, pointer_in_addressspace);
4507 }
4508
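// The allocated tile has shape [scaling_factor][dims_major_to_minor...];
// GEPIntoSharedMemory indexes it with a leading `scaling` index so that each
// virtual thread group gets its own copy of the tile.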
4509 llvm::GlobalVariable* IrEmitterUnnested::AllocateShared(
4510 const TilingScheme& tiling_scheme, llvm::Type* element_type,
4511 absl::Span<int64_t const> dimensions_major_to_minor,
4512 absl::string_view buffer_name) {
4513 CHECK(!dimensions_major_to_minor.empty());
4514 llvm::Type* array_type = nullptr;
4515 for (int i = dimensions_major_to_minor.size() - 1; i >= 0; i--) {
4516 // Iterate in minor-to-major order.
4517 int64_t dim = dimensions_major_to_minor[i];
4518 if (!array_type) {
4519 array_type = llvm::ArrayType::get(element_type, dim);
4520 } else {
4521 array_type = llvm::ArrayType::get(array_type, dim);
4522 }
4523 }
4524 array_type = llvm::ArrayType::get(array_type,
4525 tiling_scheme.GetThreadIdScalingFactor());
4526 return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4527 array_type, buffer_name);
4528 }
4529
4530 // Whether the reduction can be vectorized.
4531 static bool CanVectorizeReduction(
4532 se::CudaComputeCapability cc, mlir::lmhlo::FusionOp fusion,
4533 const ReductionDimensions& reduction_dimensions, int num_threads_x,
4534 std::array<int64_t, 3> reduction_tiling, const Shape& input_shape) {
4535 if (!reduction_dimensions.is_row_reduction) {
4536 return IsUnrollingColumnReductionBeneficial(
4537 fusion, input_shape, reduction_dimensions.dimensions[kDimX]);
4538 }
4539
4540 if (reduction_dimensions.dimensions[kDimX] % 2 != 0 ||
4541 MayPreventVectorization(fusion)) {
4542 return false;
4543 }
4544
4545 // Enabling vectorization when the number of threads is <= the warp size
4546 // leaves half or more of the threads without any work.
4547 if (reduction_dimensions.is_row_reduction && num_threads_x <= WarpSize()) {
4548 return false;
4549 }
4550
4551 if (cc.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
4552 return true;
4553 }
4554
4555 int smallest_input_dtype_bits = std::numeric_limits<int>::max();
4556 for (mlir::Value operand : fusion.getInputBuffers()) {
4557 smallest_input_dtype_bits =
4558 std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
4559 }
4560 if (cc.IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
4561 return smallest_input_dtype_bits <= 32 &&
4562 reduction_dimensions.dimensions[kDimX] %
4563 (reduction_tiling[2] * num_threads_x) ==
4564 0;
4565 }
4566 return false;
4567 }
4568
4569 StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
4570 mlir::lmhlo::FusionOp fusion, mlir::mhlo::ReduceOp first_reduce) {
4571 Shape input_shape = GetShape(first_reduce->getOperand(0));
4572 ReductionDimensions reduction_dimensions =
4573 GetReductionKindAndContiguousComponents(first_reduce);
4574 VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
4575 << " " << reduction_dimensions.dimensions[0] << " "
4576 << reduction_dimensions.dimensions[1] << " "
4577 << reduction_dimensions.dimensions[2];
4578 Vector3 reduction_tiling = GetReductionTiling(
4579 reduction_dimensions, ir_emitter_context_->cuda_compute_capability());
4580
4581 int64_t num_threads_y =
4582 reduction_dimensions.is_row_reduction ? 1 : WarpSize();
4583 int64_t num_threads_x = [&] {
4584 if (reduction_dimensions.is_row_reduction) {
4585 if (RowReductionGetRowsPerWarp(reduction_dimensions.dimensions[2]) > 1) {
4586 return reduction_dimensions.dimensions[2];
4587 }
4588 // Use 512 as default block size (threads per block) for row reductions.
4589 // For multi-output fusions, reduce the block size further to decrease
4590 // register pressure when multiple outputs are computed by each thread.
4591 int64_t fan_out = fusion.getFusionRoots().size();
4592 int64_t max_block_size =
4593 std::max(MinThreadsXRowReduction(),
4594 static_cast<int64_t>(512LL / NearestPowerOfTwo(fan_out)));
4595 return std::min(max_block_size,
4596 RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2],
4597 reduction_tiling[2]),
4598 WarpSize()));
4599 }
4600 return WarpSize();
4601 }();
4602
4603 se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
4604
4605 int smallest_input_dtype_bits = std::numeric_limits<int>::max();
4606 for (mlir::Value operand : fusion.getInputBuffers()) {
4607 smallest_input_dtype_bits =
4608 std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
4609 }
4610
4611 TilingScheme::IndexingOrder indexing_order =
4612 reduction_dimensions.is_row_reduction ? kStridedIndexingX
4613 : kLinearIndexingX;
4614 bool vectorize =
4615 CanVectorizeReduction(cc, fusion, reduction_dimensions, num_threads_x,
4616 reduction_tiling, input_shape);
4617 int vector_size = vectorize ? 2 : 1;
4618 int num_partial_results =
4619 !reduction_dimensions.is_row_reduction && vectorize ? 2 : 1;
4620 VLOG(3) << "Each threads will produce " << num_partial_results
4621 << " output(s)";
4622 reduction_tiling[kDimX] *= num_partial_results;
4623
4624 Vector3 num_threads = {1, num_threads_y, num_threads_x};
4625 int virtual_thread_scaling_factor =
4626 CalculateVirtualThreadScalingFactorForReduction(reduction_dimensions, cc);
4627 VLOG(2) << "Using virtual thread scaling: " << virtual_thread_scaling_factor;
4628
4629 TilingScheme tiling_scheme(reduction_dimensions.dimensions, reduction_tiling,
4630 num_threads, indexing_order, vector_size,
4631 virtual_thread_scaling_factor);
4632 return ReductionCodegenInfo(
4633 tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction,
4634 ReductionIsRaceFree(reduction_dimensions, reduction_tiling));
4635 }
4636
4637 // Generates the code that updates the accumulator state of the given reduce
4638 // instruction for a single element of the tile.
4639 void IrEmitterUnnested::GenerateElementForReducer(
4640 const HloReduceInstruction* reduction, llvm::Value* partial_result_index,
4641 const ReductionCodegenState& codegen_state,
4642 const llvm_ir::IrArray::Index& index_without_linear,
4643 const IrArray::Index& input_index, int num_partial_results,
4644 const ReductionOutputMap& result_ir_arrays) {
4645 HloComputation* reducer = reduction->to_apply();
4646 CHECK_EQ(reducer->num_parameters() % 2, 0);
4647
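// A variadic reducer takes N accumulator parameters followed by N input
// parameters, one pair per output of the (multi-output) reduction.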
4648 absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
4649 absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
4650 for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
4651 const ReductionCodegenState::ReductionCalculationState& state =
4652 codegen_state.GetCalculationStateFor(reduction, red_idx);
4653
4654 llvm::AllocaInst* input_address = state.input_address;
4655 llvm::AllocaInst* partial_reduction_result_address =
4656 state.partial_result_address;
4657 llvm::Value* const input_ir_value = *state.input_gen(
4658 num_partial_results > 1 ? index_without_linear : input_index);
4659 b_.CreateStore(input_ir_value, input_address);
4660 llvm::Value* partial_result_address =
4661 InBoundsGEP(partial_reduction_result_address->getAllocatedType(),
4662 partial_reduction_result_address, {partial_result_index});
4663 reduction_accumulators.push_back(partial_result_address);
4664 reduction_input_value.push_back(input_address);
4665 }
4666
4667 absl::InlinedVector<llvm::Value*, 4> reduction_params;
4668 for (llvm::Value* acc : reduction_accumulators) {
4669 reduction_params.push_back(acc);
4670 }
4671 for (llvm::Value* value : reduction_input_value) {
4672 reduction_params.push_back(value);
4673 }
4674
4675 // Emit a call to the variadic reducer. Since it may be returning a
4676 // tuple, we can't return it directly as a value. Instead, before
4677 // the call, we create N (N = # arguments in the tuple) allocas, one
4678 // for each returned argument, then when we make the call we pass N
4679 // pointers as last parameters, the called computation writes into
4680 // those pointers, and we have returned values on the stack (as well
4681 // as pointers to them).
4682 StatusOr<std::vector<llvm::Value*>> returned_scalars =
4683 ComputeNestedElementFromAddrs(*reducer, reduction_params);
4684 TF_CHECK_OK(returned_scalars.status());
4685
4686 for (int i = 0; i < returned_scalars->size(); i++) {
4687 b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
4688 }
4689 }
4690
4691 Status IrEmitterUnnested::EmitIRForReduction(
4692 mlir::lmhlo::FusionOp fusion,
4693 absl::Span<HloInstruction* const> instr_index_group,
4694 FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
4695 const ReductionCodegenInfo& reduction_info, const Shape& input_shape) {
4696 std::vector<const HloReduceInstruction*> reductions;
4697 ExtraOutputGensMap extra_output_gens;
4698 for (const HloInstruction* hlo : instr_index_group) {
4699 if (IsReductionFromOrToContiguousDimensions(*hlo)) {
4700 reductions.push_back(Cast<HloReduceInstruction>(hlo));
4701 } else {
4702 extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
4703 }
4704 }
4705
4706 CHECK(!reductions.empty()) << " expected at least one reduce instruction.";
4707 const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme();
4708 CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0);
4709 llvm::Type* index_ty =
4710 GetIndexTypeForKernel(fusion,
4711 tiling_scheme.GetNumThreadsPerBlockPhysical() *
4712 tiling_scheme.GetNumberOfBlocksPhysical(),
4713 &b_);
4714 ReductionCodegenState codegen_state = GenerateReductionCodegenState(
4715 fusion, reduction_info, reductions, fused_emitter);
4716
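// Emit a tiled loop in which every visited element updates the per-thread
// partial accumulators, then emit the row/column epilogues that combine the
// partials across threads and write the final outputs.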
4717 EmitElementFunction emit_reduction_element =
4718 [&](const ThreadIdInfo& thread_id_info,
4719 const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4720 llvm::Value* x_loc, llvm::Value* x_iter_num) {
4721 IrArray::Index input_index = GetUnnormalizedIndex(
4722 index, input_shape, &b_, codegen_state.GetTilingScheme());
4723
4724 llvm::Value* partial_result_index =
4725 codegen_state.IsRowReduction() ? b_.getInt32(0) : x_iter_num;
4726
4727 // Clear the linear index field of the IrArray::Index to enable the use
4728 // of GetElementPointer with array types. This enables the vectorization
4729 // of the computation for different partial results. Use this index if
4730 // 'num_partial_results > 1'.
4731 int num_partial_results = codegen_state.GetNumPartialResults();
4732 llvm_ir::IrArray::Index index_without_linear = IrArray::Index(
4733 input_index.multidim(), input_shape, input_index.GetType());
4734
4735 // Emit code to generate the input and perform the reduction computation
4736 // for each reduction instruction.
4737 for (const HloReduceInstruction* reduce : reductions) {
4738 GenerateElementForReducer(reduce, partial_result_index, codegen_state,
4739 index_without_linear, input_index,
4740 num_partial_results, result_ir_arrays);
4741 }
4742
4743 // Emit code to generate the output for the non-reduction instructions
4744 // in the fusion, if any.
4745 TF_CHECK_OK(EmitExtraOutputsForReduce(
4746 result_ir_arrays, input_index, reduction_info, extra_output_gens));
4747 };
4748
4749 TF_ASSIGN_OR_RETURN(
4750 TilingKernelInfo tiling_kernel_info,
4751 EmitTilingKernel(
4752 tiling_scheme, index_ty,
4753 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4754 ValueVector3 tile_dimensions) {
4755 EmitTile(codegen_state.GetTilingScheme(), index, thread_id_info,
4756 tile_dimensions, emit_reduction_element);
4757 }));
4758
4759 KernelSupportLibrary ksl(&b_);
4760 for (const HloReduceInstruction* reduce : reductions) {
4761 for (int partial_result_idx = 0;
4762 partial_result_idx < reduction_info.GetNumPartialResults();
4763 ++partial_result_idx) {
4764 if (codegen_state.IsRowReduction()) {
4765 EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state,
4766 index_ty, result_ir_arrays, reduce,
4767 partial_result_idx);
4768 } else {
4769 EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state,
4770 index_ty, result_ir_arrays,
4771 reduce, partial_result_idx);
4772 }
4773 }
4774 }
4775
4776 return OkStatus();
4777 }
4778
4779 namespace {
4780
4781 // Returns whether the `instr` is either a constant, a scalar, or a
4782 // broadcasted constant/scalar.
4783 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
4784 return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
4785 (HloOpcode::kBroadcast == instr.opcode() &&
4786 (instr.operand(0)->IsConstant() ||
4787 ShapeUtil::IsScalar(instr.operand(0)->shape())));
4788 }
4789
4790 // Recursive helper for GetFusionRoots below.
4791 static void GetFusionRootsRec(HloInstruction* root,
4792 std::vector<HloInstruction*>& out) {
4793 if (root->opcode() == HloOpcode::kGetTupleElement) {
4794 return GetFusionRootsRec(root->mutable_operand(0), out);
4795 } else if (root->opcode() == HloOpcode::kTuple) {
4796 for (int i = 0; i < root->operand_count(); i++) {
4797 GetFusionRootsRec(root->mutable_operand(i), out);
4798 }
4799 } else {
4800 if (!out.empty() && out.back() == root) {
4801 return;
4802 }
4803 CHECK(!absl::c_linear_search(out, root))
4804 << "Fusion root contains instruction " << root->ToString()
4805 << " multiple times";
4806 out.push_back(root);
4807 }
4808 }
4809
4810 // Returns the instructions which are roots of the fusion, following the
4811 // operands of GTE instructions in the root tuple. Groups multiple subsequent
4812 // instructions with the same root. CHECKs that the fusion never outputs the
4813 // same instruction twice, and that there are no explicitly created tuples or
4814 // nested GTEs in the fusion output.
4815 //
4816 // For input: (tuple (gte R1) (gte R1) O2)
4817 // Expected output: [R1, O2]
4818 //
4819 // For input: (tuple R1 R2 O2)
4820 // Expected output: [R1, R2, O2]
4821 //
4822 // For input: (tuple (gte R1) (gte R1) R2 O3)
4823 // Expected output: [R1, R2, O3]
4824 //
4825 // For input: R1
4826 // Expected output: [R1]
4827 static std::vector<HloInstruction*> GetFusionRoots(
4828 HloComputation* computation) {
4829 std::vector<HloInstruction*> out;
4830 GetFusionRootsRec(computation->root_instruction(), out);
4831 return out;
4832 }
4833
4834 // Divides `num_reduces` reduces into groups. Different groups will be executed
4835 // in parallel. Generally speaking, we'd like to run the reduce instructions
4836 // in parallel without incurring too much recomputation overhead. The current
4837 // heuristic is to place reduce instructions that share nothing or only
4838 // (broadcasted) scalars/constants into different groups; otherwise, they are
4839 // placed in the same group. Non-reduce instructions always go with the reduce
4840 // instructions into the same group so long as they share any predecessors.
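// For example, two reduce roots that share only a broadcasted constant input
// end up in different groups, while roots that share any other predecessor
// are merged into one group.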
4841 std::vector<std::vector<HloInstruction*>> GroupDisjointReductions(
4842 HloComputation* fused_computation) {
4843 const Shape& root_shape = fused_computation->root_instruction()->shape();
4844 int num_fusion_outputs =
4845 fused_computation->root_instruction()->opcode() == HloOpcode::kTuple
4846 ? root_shape.tuple_shapes_size()
4847 : 1;
4848 CHECK_NE(0, num_fusion_outputs);
4849 if (num_fusion_outputs == 1) {
4850 return {{fused_computation->root_instruction()}};
4851 }
4852
4853 std::vector<HloInstruction*> roots = GetFusionRoots(fused_computation);
4854 HloInstructionMap<tensorflow::UnionFind<HloInstruction*>> disjoint_sets;
4855
4856 for (HloInstruction* root : roots) {
4857 disjoint_sets[root].Get() = root;
4858 }
4859
4860 std::unique_ptr<HloReachabilityMap> reachability_map =
4861 HloReachabilityMap::Build(fused_computation);
4862 for (HloInstruction* instr : fused_computation->instructions()) {
4863 std::vector<HloInstruction*> reached_output_ids;
4864 bool added_to_reduce = false;
4865 for (HloInstruction* output : roots) {
4866 if (HloOpcode::kReduce == output->opcode() &&
4867 (IsBroadcastedConstantOrScalar(*instr))) {
4868 if (added_to_reduce) {
4869 // Do not group more than one output reduce instruction through
4870 // broadcasted constants or scalars, as the recomputation should be
4871 // acceptable.
4872 VLOG(3) << "Skip broadcasted constant or scalar "
4873 << instr->ToString();
4874 continue;
4875 }
4876 }
4877 // Now group output instructions if they have common predecessors.
4878 if (reachability_map->IsReachable(instr, output)) {
4879 VLOG(3) << "Reaching " << output->ToString() << " from "
4880 << instr->ToString();
4881 reached_output_ids.push_back(output);
4882 if (HloOpcode::kReduce == output->opcode()) {
4883 added_to_reduce = true;
4884 }
4885 }
4886 }
4887 for (size_t j = 1; j < reached_output_ids.size(); ++j) {
4888 disjoint_sets[reached_output_ids[0]].Merge(
4889 &disjoint_sets[reached_output_ids[j]]);
4890 }
4891 }
4892
4893 // Place output instructions in the same set into the same group.
4894 HloInstructionMap<std::vector<HloInstruction*>> groups;
4895 for (HloInstruction* root : roots) {
4896 groups[disjoint_sets[root].Get()].push_back(root);
4897 }
4898
4899 std::vector<std::vector<HloInstruction*>> ret;
4900 absl::c_for_each(
4901 groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
4902 return ret;
4903 }
4904
4905 } // namespace
4906
4907 Status IrEmitterUnnested::EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion) {
4908 llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots();
4909
4910 TF_ASSIGN_OR_RETURN(HloComputation * fused_computation,
4911 GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
4912 /*is_fusion=*/true));
4913
4914 // Group disjoint reductions into groups that can be executed in parallel.
4915 std::vector<std::vector<HloInstruction*>> instr_index_groups =
4916 GroupDisjointReductions(fused_computation);
4917
4918 VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
4919 MlirToString(fusion));
4920
4921 mlir::mhlo::ReduceOp first_reduce = mlir::cast<mlir::mhlo::ReduceOp>(
4922 *absl::c_find_if(fusion_roots, [](mlir::Operation* op) {
4923 return IsReductionFromOrToContiguousDimensions(op);
4924 }));
4925 TF_ASSIGN_OR_RETURN(ReductionCodegenInfo reduction_info,
4926 ComputeReductionCodegenInfo(fusion, first_reduce));
4927 const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme();
4928
4929 // block_y_count is set to instr_index_groups.size(), so that each reduction
4930 // group can be run in parallel by a different BlockIdy.
4931 LaunchDimensions launch_dimensions(
4932 {/*x=*/tiling_scheme.GetNumberOfBlocksPhysical(),
4933 /*y=*/static_cast<int64_t>(instr_index_groups.size()),
4934 /*z=*/1},
4935 {/*x=*/tiling_scheme.GetNumThreadsPerBlockPhysical(), /*y=*/1, /*z=*/1});
4936 VLOG(3) << "Launch dimensions of " << mlir::GetNameFromLoc(fusion.getLoc())
4937 << launch_dimensions.ToString();
4938
4939 std::vector<llvm_ir::IrArray> ir_arrays;
4940 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk,
4941 BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
4942 launch_dimensions));
4943
4944 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
4945 GetNestedComputer());
4946 FusedIrEmitter fused_emitter(elemental_emitter);
4947 CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
4948 for (int i = 0; i < fused_computation->num_parameters(); i++) {
4949 llvm_ir::IrArray ir_array = ir_arrays[i];
4950 HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
4951 fused_emitter.BindGenerator(
4952 *fused_operand,
4953 [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) {
4954 return ir_array.EmitReadArrayElement(index, &b_,
4955 fused_operand->name());
4956 });
4957 }
4958
4959 // Get outputs.
4960 ReductionOutputMap result_ir_arrays;
4961
4962 // Skip all parameter buffers first.
4963 int ir_arrays_idx = fused_computation->num_parameters();
4964 std::vector<HloInstruction*> roots = GetFusionRoots(fused_computation);
4965 for (HloInstruction* root : roots) {
4966 int get_num_results = GetNumOutputs(root->shape());
4967 result_ir_arrays[root] =
4968 absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, get_num_results);
4969 ir_arrays_idx += get_num_results;
4970 }
4971
4972 // We always use the first reduce as representative to construct
4973 // ReductionCodegenInfo, since all the reductions are required to have the
4974 // same shape and layout as verified by `IsFusedReductionOutputConsistent()`.
4975 TF_ASSIGN_OR_RETURN(ReductionCodegenInfo reduction_codegen_info,
4976 ComputeReductionCodegenInfo(fusion, first_reduce));
4977
4978 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4979
4980 // Use raw block_id_y to select the i-th parallel reduction to run. Using
4981 // block_id_y instead of block_id_x simplifies the index calculation
4982 // for reduction code generation as the block_id_y is orthogonal to
4983 // the indices used within the reductions.
4984 llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
4985 gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
4986 llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
4987 llvm::cast<llvm::Instruction>(raw_block_id_y));
4988 for (int i = 0; i < instr_index_groups.size(); ++i) {
4989 TF_RETURN_IF_ERROR(ksl.IfWithStatus(
4990 StrCat("reduce-group-", i),
4991 b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] {
4992 return EmitIRForReduction(
4993 fusion, instr_index_groups[i], fused_emitter, result_ir_arrays,
4994 reduction_codegen_info, GetShape(first_reduce->getOperand(0)));
4995 }));
4996 }
4997
4998 ThunkSequence thunks;
4999 if (!reduction_codegen_info.IsRaceFree()) {
5000 for (int i = 0; i < fusion_roots.size(); ++i) {
5001 mlir::Operation* output_instruction = fusion_roots[i];
5002 if (IsReductionFromOrToContiguousDimensions(output_instruction)) {
5003 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
5004 BuildFusedInitializerThunk(fusion, i));
5005 thunks.push_back(std::move(initializer_thunk));
5006 }
5007 }
5008 }
5009
5010 thunks.push_back(std::move(kernel_thunk));
5011 auto sequential_thunk = std::make_unique<SequentialThunk>(
5012 GetThunkInfo(fusion), std::move(thunks));
5013 AddThunkToThunkSequence(std::move(sequential_thunk));
5014
5015 return OkStatus();
5016 }
5017
5018 // Emits code for slices based on the below structure. An if statement with
5019 // a guarding condition is generated for each ROOT slice.
5020 //
5021 // Pseudo code:
5022 //
5023 // Compute values of slice input operands
5024 //
5025 // Compute guarding_cond0
5026 // if (guarding_cond0) {
5027 // Write to output of slice0
5028 // }
5029 //
5030 // Compute guarding_cond1
5031 // if (guarding_cond1) {
5032 // Write to output of slice1
5033 // }
5034 //
5035 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5036 const HloComputation* fused_computation,
5037 absl::Span<const llvm_ir::IrArray> ir_arrays,
5038 const llvm_ir::IrArray::Index& index) {
5039 VLOG(10) << "Emitting slice input fusion for "
5040 << fused_computation->ToString();
5041
5042 HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5043 auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5044 if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5045 return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5046 }
5047 CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5048 return slice_or_tuple->operands();
5049 }();
5050
5051 // Emit input operand values of slices.
5052 std::vector<llvm::Value*> input_ir_values;
5053 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5054 GetNestedComputer());
5055 FusedIrEmitter fused_emitter(elem_emitter);
5056 for (int i = 0; i < fused_computation->num_parameters(); i++) {
5057 fused_emitter.BindGenerator(
5058 *fused_computation->parameter_instruction(i),
5059 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5060 return ir_arrays[i].EmitReadArrayElement(index, &b_);
5061 });
5062 }
5063 for (const HloInstruction* slice : slice_instructions) {
5064 auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
5065 input_ir_values.push_back(input_generator(index).ValueOrDie());
5066 }
5067
5068 // Emit for slice_instructions.
5069 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5070 for (int64_t i = 0; i < slice_instructions.size(); ++i) {
5071 HloInstruction* slice = slice_instructions[i];
5072
5073 // guarding_cond := index >= start && index < limit, for each dim.
5074 std::vector<llvm::Value*> index_within_ranges;
5075 for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5076 CHECK_EQ(slice->slice_strides(dim), 1);
5077 auto larger_or_equal_than_start = b_.CreateICmpSGE(
5078 index.multidim()[dim],
5079 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5080 llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5081 index.multidim()[dim],
5082 index.GetConstantWithIndexType(slice->slice_limits(dim)));
5083 llvm::Value* within_range =
5084 b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5085 index_within_ranges.push_back(within_range);
5086 }
5087 llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5088
5089 auto emit_slice_elem_func = [&] {
5090 const std::vector<llvm::Value*>& src_multidim = index.multidim();
5091 std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5092 for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5093 dst_multidim[dim] =
5094 Sub(src_multidim[dim],
5095 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5096 }
5097 llvm_ir::IrArray src_ir_array =
5098 ir_arrays[fused_computation->num_parameters() + i];
5099 IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5100 index.GetType());
5101 src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5102 &b_);
5103 };
5104
5105 ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5106 }
5107 return OkStatus();
5108 }
5109
5110 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5111 mlir::Operation* op) {
5112 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
5113
5114 constexpr int unroll_factor = 1;
5115
5116 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5117 GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
5118 /*is_fusion=*/true));
5119
5120 TF_ASSIGN_OR_RETURN(Shape element_shape,
5121 GetConsistentInputShapeForRootSlices(fused_computation));
5122 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
5123 CalculateLaunchDimensions(
5124 element_shape, ir_emitter_context_->gpu_device_info(),
5125 {unroll_factor}));
5126
5127 std::vector<llvm_ir::IrArray> ir_arrays;
5128 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
5129 BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
5130 launch_dimensions));
5131
5132 Status emit_status =
5133 ParallelLoopEmitter(
5134 [&](const llvm_ir::IrArray::Index index) -> Status {
5135 return EmitElementForInputFusibleSlices(fused_computation,
5136 ir_arrays, index);
5137 },
5138 element_shape, launch_dimensions, &b_)
5139 .EmitLoop(IrName(GetIrNameFromLoc(fusion.getLoc())),
5140 GetIndexTypeForKernel(
5141 fusion, launch_dimensions.launch_bound(), &b_));
5142
5143 thunk_sequence_.emplace_back(std::move(kernel_thunk));
5144
5145 return emit_status;
5146 }
5147
5148 Status IrEmitterUnnested::EmitDynamicUpdateSlice(
5149 mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation) {
5150 // Fusion node with dynamic-update-slice as the root where the op's input
5151 // (i.e. array to update) shares the same slice as its output. In this case
5152 // we have a special algorithm that modifies the output in place without
5153 // touching the un-updated elements.
  CHECK_EQ(1, GetHloOutputs(fusion_op).size());

  // Shape of the dynamic-update-slice's "update" operand.
  Shape update_shape =
      fused_computation->root_instruction()->operand(1)->shape();

  TF_ASSIGN_OR_RETURN(
      LaunchDimensions launch_dimensions,
      CalculateLaunchDimensions(update_shape,
                                ir_emitter_context_->gpu_device_info()));

  // Set up kernel thunk and fused ir emitter.
  std::vector<llvm_ir::IrArray> ir_arrays;
  TF_ASSIGN_OR_RETURN(auto fusion_thunk,
                      BuildKernelThunk(fusion_op, GetThunkInfo(fusion_op),
                                       &ir_arrays, launch_dimensions));
  AddThunkToThunkSequence(std::move(fusion_thunk));

  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());

  FusedIrEmitter fused_emitter(elemental_emitter);

  for (int i = 0; i < fused_computation->num_parameters(); i++) {
    auto fused_operand = fused_computation->parameter_instruction(i);
    fused_emitter.BindGenerator(
        *fused_operand, [this, &ir_arrays, i,
                         fused_operand](const llvm_ir::IrArray::Index& index) {
          return ir_arrays[i].EmitReadArrayElement(index, &b_,
                                                   fused_operand->name());
        });
  }

  // Array to write into. Because this is an in-place operation, this is the
  // same as operand 0's array.
  const IrArray& output_array = ir_arrays.back();

  return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
      fused_computation, output_array, &fused_emitter, launch_dimensions, &b_);
}

Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op,
                                      const HloComputation* fused_computation) {
  auto* root = fused_computation->root_instruction();

  ThunkSequence thunks;
  // The initialization from 'operand' uses different loop bounds, so emit it
  // in a separate kernel. Treat it like a loop fusion, writing to the output
  // buffer.
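  // The lowering therefore produces two kernels, wrapped in a single
  // SequentialThunk at the end of this function:
  //   1) an initialization kernel that materializes root->operand(0) into the
  //      output buffer, and
  //   2) the scatter kernel itself, which applies the updates in place.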
  {
    auto unroll_factor = ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
    const Shape& element_shape = root->shape();
    TF_ASSIGN_OR_RETURN(
        LaunchDimensions launch_dimensions,
        CalculateLaunchDimensions(element_shape,
                                  ir_emitter_context_->gpu_device_info(),
                                  {unroll_factor, /*few_waves=*/false}));

    std::vector<llvm_ir::IrArray> ir_arrays;
    TF_ASSIGN_OR_RETURN(auto operand_thunk,
                        BuildKernelThunk(fusion_op, Thunk::ThunkInfo(),
                                         &ir_arrays, launch_dimensions));
    thunks.push_back(std::move(operand_thunk));

    GpuElementalIrEmitter operand_elemental_emitter(hlo_module_config_, module_,
                                                    &b_, GetNestedComputer());
    FusedIrEmitter operand_fused_emitter(operand_elemental_emitter);
    for (int i = 0; i < fused_computation->num_parameters(); i++) {
      auto fused_operand = fused_computation->parameter_instruction(i);
      operand_fused_emitter.BindGenerator(
          *fused_operand,
          [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
            return ir_arrays[i].EmitReadArrayElement(index, &b_,
                                                     fused_operand->name());
          });
    }
    TF_ASSIGN_OR_RETURN(auto generator,
                        operand_fused_emitter.GetGenerator(*root->operand(0)));

    TF_RETURN_IF_ERROR(
        ParallelLoopEmitter(generator, {ir_arrays.back()}, launch_dimensions,
                            &b_, {unroll_factor})
            .EmitLoop(IrName(GetIrNameFromLoc(fusion_op.getLoc())),
                      GetIndexTypeForKernel(
                          fusion_op, launch_dimensions.launch_bound(), &b_)));
  }

  // Now build the actual scatter, reading and writing to the freshly
  // filled output buffer.
  {
    const Shape& updates_shape = root->operand(2)->shape();
    TF_ASSIGN_OR_RETURN(
        LaunchDimensions launch_dimensions,
        CalculateLaunchDimensions(updates_shape,
                                  ir_emitter_context_->gpu_device_info()));
    std::vector<llvm_ir::IrArray> ir_arrays;
    TF_ASSIGN_OR_RETURN(auto scatter_thunk,
                        BuildKernelThunk(fusion_op, Thunk::ThunkInfo(),
                                         &ir_arrays, launch_dimensions));
    thunks.push_back(std::move(scatter_thunk));
    // Spin up a new fused emitter for the scatter kernel and emit it.
    GpuElementalIrEmitter scatter_elemental_emitter(hlo_module_config_, module_,
                                                    &b_, GetNestedComputer());
    FusedIrEmitter scatter_fused_emitter(scatter_elemental_emitter);
    for (int i = 0; i < fused_computation->num_parameters(); i++) {
      auto fused_operand = fused_computation->parameter_instruction(i);
      scatter_fused_emitter.BindGenerator(
          *fused_operand,
          [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
            return ir_arrays[i].EmitReadArrayElement(index, &b_,
                                                     fused_operand->name());
          });
    }

    TF_ASSIGN_OR_RETURN(const auto dim_numbers,
                        mlir::LhloDialectEmitter::GetScatterDimensionNumbers(
                            root, fusion_op.getContext()));

    ScatterDescriptor desc;
    desc.name = IrName(root);
    desc.operand_shape = root->operand(0)->shape();
    desc.scatter_indices_shape = root->operand(1)->shape();
    desc.updates_shape = updates_shape;
    desc.dim_numbers = dim_numbers;
    desc.unique_indices = root->unique_indices();
    desc.update_computation = root->called_computations()[0];
    desc.output = ir_arrays.back();
    TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen,
                        scatter_fused_emitter.GetGenerator(*root->operand(1)));
    TF_ASSIGN_OR_RETURN(desc.updates_gen,
                        scatter_fused_emitter.GetGenerator(*root->operand(2)));
    desc.get_index_type = [&](int64_t launch_size) {
      return GetIndexTypeForKernel(root, launch_size, &b_);
    };

    TF_RETURN_IF_ERROR(
        EmitScatter(desc, thunks.back().get(), launch_dimensions));
  }
  AddThunkToThunkSequence(std::make_unique<SequentialThunk>(
      GetThunkInfo(fusion_op), std::move(thunks)));
  return OkStatus();
}

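// Dispatches a single LMHLO / LMHLO-GPU operation to the emitter that lowers
// it to thunks (and, where applicable, device kernels). Ops with no runtime
// effect (constants, views, casts, terminators) are skipped; any op without a
// dedicated emitter is reported as an internal error.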
Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
  if (mlir::isa<mlir::func::ConstantOp, mlir::arith::ConstantOp,
                mlir::memref::ViewOp, mlir::memref::ReinterpretCastOp,
                mlir::func::ReturnOp, mlir::lmhlo::TerminatorOp>(op)) {
    return OkStatus();
  }

  if (mlir::isa<mlir::memref::GetGlobalOp>(op)) {
    return EmitConstant(op);
  }

  if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(op)) {
    if (call.getCallTargetName() == "PadToStatic") {
      return EmitPadToStatic(op);
    }
    if (call.getCallTargetName() == "SliceToDynamic") {
      return EmitSliceToDynamic(op);
    }
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (call.getCallTargetName() == kTriangularSolveCallTarget) {
      return EmitTriangularSolveCustomCall(op);
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

    return EmitCustomCallThunk(op);
  }

  if (mlir::isa<mlir::lmhlo_gpu::GEMMOp>(op)) {
    return EmitGemmThunk(op);
  }

#if GOOGLE_CUDA
  if (mlir::isa<mlir::lmhlo_gpu::CublasLtMatmulOp>(op)) {
    return EmitCublasLtMatmulThunk(op);
  }
#endif  // GOOGLE_CUDA

  if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
                mlir::lmhlo_gpu::ConvForwardFusedOp,
                mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
                mlir::lmhlo_gpu::ConvBackwardFilterOp,
                mlir::lmhlo_gpu::ConvBackwardInputOp>(op)) {
    return EmitConvolutionThunk(op);
  }

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) {
    return EmitCholeskyThunk(op);
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  if (mlir::isa<mlir::lmhlo::FftOp>(op)) {
    return EmitFftThunk(op);
  }

  if (mlir::isa<mlir::lmhlo::TriangularSolveOp>(op)) {
    return InternalError(
        "TriangularSolve is implemented as a custom-call; we do not expect to "
        "lower a true HLO TriangularSolve op.");
  }

  if (mlir::isa<mlir::lmhlo::FusionOp>(op)) {
    return EmitFusion(op);
  }

  if (mlir::isa<mlir::lmhlo::SelectAndScatterOp>(op)) {
    return EmitSelectAndScatter(op);
  }

  if (mlir::isa<mlir::lmhlo::RngGetAndUpdateStateOp>(op)) {
    return EmitRngGetAndUpdateState(op);
  }

  if (mlir::isa<mlir::lmhlo::ScatterOp>(op)) {
    return EmitScatter(op);
  }

  if (mlir::isa<mlir::lmhlo::SortOp>(op)) {
    return EmitSort(op);
  }

  if (mlir::isa<mlir::lmhlo::ReplicaIdOp>(op)) {
    return EmitReplicaOrPartitionId<ReplicaIdThunk, mlir::lmhlo::ReplicaIdOp>(
        op);
  }

  if (mlir::isa<mlir::lmhlo::PartitionIdOp>(op)) {
    return EmitReplicaOrPartitionId<PartitionIdThunk,
                                    mlir::lmhlo::PartitionIdOp>(op);
  }

  if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(op)) {
    return EmitCollectivePermute(op);
  }

  if (mlir::isa<mlir::lmhlo::AllGatherOp>(op)) {
    return EmitNcclThunk<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(op);
  }

  if (mlir::isa<mlir::lmhlo::AllReduceOp>(op)) {
    return EmitNcclThunk<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(op);
  }

  if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
    return EmitNcclThunk<NcclAllReduceStartThunk,
                         mlir::lmhlo_gpu::AllReduceStartOp>(op);
  }

  if (mlir::isa<mlir::lmhlo_gpu::AllReduceDoneOp>(op)) {
    return EmitAllReduceDone(op);
  }

  if (mlir::isa<mlir::lmhlo::ReduceScatterOp>(op)) {
    return EmitNcclThunk<NcclReduceScatterThunk, mlir::lmhlo::ReduceScatterOp>(
        op);
  }

  if (mlir::isa<mlir::lmhlo::AllToAllOp>(op)) {
    return EmitNcclThunk<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(op);
  }

  if (mlir::isa<mlir::lmhlo::InfeedOp>(op)) {
    return EmitInfeed(op);
  }

  if (mlir::isa<mlir::lmhlo::OutfeedOp>(op)) {
    return EmitOutfeed(op);
  }

  if (mlir::isa<mlir::lmhlo::CaseOp>(op)) {
    return EmitConditional(op);
  }

  if (mlir::isa<mlir::lmhlo::WhileOp>(op)) {
    return EmitWhile(op);
  }

  if (mlir::isa<mlir::gpu::LaunchFuncOp>(op)) {
    return EmitLaunchFunc(op);
  }

  // Remaining arith.constant ops are the gpu.launch_func dimensions as a
  // result of inlining the fusion region after lowering. They can safely be
  // skipped because constants have no side effects.
  if (mlir::isa<mlir::arith::ConstantOp>(op)) {
    return OkStatus();
  }

  return InternalError("Unrecognized op: %s", MlirToString(op));
}

Status IrEmitterUnnested::EmitLmhloRegion(mlir::Region* region) {
  for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) {
    TF_RETURN_IF_ERROR(EmitOp(&op));
  }
  return OkStatus();
}

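// Registers the MLIR dialects and LLVM IR translation interfaces that the
// emitter expects to be available on the MLIRContext holding the LMHLO input
// (the LLVM, NVVM and ROCDL translations cover the NVPTX and AMDGPU backends).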
void IrEmitterUnnested::GetDependentDialects(mlir::DialectRegistry& registry) {
  registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
                  mlir::gpu::GPUDialect, mlir::lmhlo::LmhloDialect,
                  mlir::lmhlo_gpu::LmhloGpuDialect, mlir::mhlo::MhloDialect>();
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  mlir::registerROCDLDialectTranslation(registry);
}

Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(mlir::Operation* op) {
  auto module = op->getParentOfType<mlir::ModuleOp>();
  // Include the HloModule's unique_id in the thunk's module name so that xprof
  // shows different modules differently, addressing b/202415436#comment24.
  // xprof calls this the "program_id".
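  // For example (hypothetical values), the resulting annotation looks like
  //   Thunk:#hlo_op=fusion.1,hlo_module=my_module,program_id=7#
  // or, when the module has no "mhlo.unique_id" attribute,
  //   Thunk:#hlo_op=fusion.1,hlo_module=my_module#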
  std::string unique_id_str;
  if (auto unique_id_attr =
          module->getAttrOfType<mlir::IntegerAttr>("mhlo.unique_id")) {
    unique_id_str = absl::StrFormat(",program_id=%d",
                                    unique_id_attr.getValue().getZExtValue());
  }
  Thunk::ThunkInfo thunk_info;
  thunk_info.profile_annotation = absl::StrFormat(
      "Thunk:#hlo_op=%s,hlo_module=%s%s#", mlir::GetNameFromLoc(op->getLoc()),
      mlir::GetNameFromLoc(module->getLoc()), unique_id_str);
  return thunk_info;
}

}  // namespace gpu
}  // namespace xla