xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <numeric>
25 #include <optional>
26 #include <string>
27 #include <type_traits>
28 #include <utility>
29 #include <vector>
30 
31 #include "absl/algorithm/container.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/container/inlined_vector.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_format.h"
36 #include "absl/types/span.h"
37 #include "llvm/ADT/APInt.h"
38 #include "llvm/ADT/StringRef.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/Function.h"
41 #include "llvm/IR/IRBuilder.h"
42 #include "llvm/IR/Instructions.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/Linker/Linker.h"
46 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
47 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
48 #include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
49 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
50 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
51 #include "mlir/IR/Attributes.h"  // from @llvm-project
52 #include "mlir/IR/Builders.h"  // from @llvm-project
53 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
54 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
55 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
56 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"  // from @llvm-project
57 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"  // from @llvm-project
58 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"  // from @llvm-project
59 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
60 #include "tensorflow/compiler/mlir/utils/name_utils.h"
61 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
62 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
63 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
64 #include "tensorflow/compiler/xla/layout_util.h"
65 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
66 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
67 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
68 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
69 #include "tensorflow/compiler/xla/permutation_util.h"
70 #include "tensorflow/compiler/xla/primitive_util.h"
71 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
72 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
73 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
74 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
77 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
78 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
79 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
81 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
82 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
83 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
84 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
85 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
86 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
88 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
89 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
90 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
91 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
92 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
93 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
94 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
95 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
97 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
98 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
99 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
100 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
101 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
102 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
103 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
104 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
105 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
106 #include "tensorflow/compiler/xla/service/hlo_computation.h"
107 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
108 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
109 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
110 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
111 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
112 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
113 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
114 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
115 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
116 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
117 #include "tensorflow/compiler/xla/service/name_uniquer.h"
118 #include "tensorflow/compiler/xla/shape.h"
119 #include "tensorflow/compiler/xla/shape_util.h"
120 #include "tensorflow/compiler/xla/status_macros.h"
121 #include "tensorflow/compiler/xla/union_find.h"
122 #include "tensorflow/compiler/xla/util.h"
123 #include "tensorflow/compiler/xla/xla_data.pb.h"
124 #include "tensorflow/core/platform/errors.h"
125 #include "tensorflow/core/platform/human_readable_json.h"
126 #include "tensorflow/core/platform/logging.h"
127 
128 #if GOOGLE_CUDA
129 #include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h"
130 #endif  // GOOGLE_CUDA
131 
132 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
133 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
134 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
135 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
136 
137 namespace xla {
138 namespace gpu {
139 
140 namespace {
141 
142 using absl::InlinedVector;
143 using absl::StrCat;
144 using llvm_ir::IrArray;
145 using llvm_ir::IrName;
146 using std::optional;
147 
148 const auto kDimX = TilingScheme::DimX;
149 const auto kDimY = TilingScheme::DimY;
150 const auto kDimZ = TilingScheme::DimZ;
151 const auto kDimTot = TilingScheme::DimTot;
152 
153 const auto kLinearIndexingX = TilingScheme::LinearIndexingX;
154 const auto kStridedIndexingX = TilingScheme::StridedIndexingX;
155 
156 void AnnotateWithInt32Value(std::string name, int64_t value,
157                             const std::string& kernel_name,
158                             llvm::Module* llvm_module) {
159   llvm::NamedMDNode* nvvm_annotations_node =
160       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
161   llvm::Function* ir_kernel = llvm_module->getFunction(kernel_name.c_str());
162   llvm::LLVMContext& llvm_context = llvm_module->getContext();
163 
164   nvvm_annotations_node->addOperand(llvm::MDNode::get(
165       llvm_context,
166       {llvm::ConstantAsMetadata::get(ir_kernel),
167        llvm::MDString::get(llvm_context, name),
168        llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
169            llvm::IntegerType::get(llvm_context, /*NumBits=*/32), value))}));
170 }
171 
172 // Annotates the launch dimensions of the corresponding IR kernel in
173 // `llvm_module`.
174 void AnnotateThunkLaunchDimensions(const LaunchDimensions& launch_dims,
175                                    const std::string& kernel_name,
176                                    llvm::Module* llvm_module) {
177   // Add __launch_bounds__ to metadata. This limits registers per thread to
178   // avoid out-of-resources launching errors.
179 
180   // Our launch bounds are exact, so we can specify them as
181   // reqntid[xyz] rather than maxntid[xyz].
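  // Illustrative example (the exact metadata shape may differ): a 256x4x1
  // thread block produces nvvm.annotations entries roughly of the form
  //   !{<kernel>, !"reqntidx", i32 256} and !{<kernel>, !"reqntidy", i32 4};
  // "reqntidz" is omitted because the z extent is 1.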
182   AnnotateWithInt32Value("reqntidx", launch_dims.thread_counts_per_block().x,
183                          kernel_name, llvm_module);
184   if (launch_dims.thread_counts_per_block().y > 1) {
185     AnnotateWithInt32Value("reqntidy", launch_dims.thread_counts_per_block().y,
186                            kernel_name, llvm_module);
187   }
188   if (launch_dims.thread_counts_per_block().z > 1) {
189     AnnotateWithInt32Value("reqntidz", launch_dims.thread_counts_per_block().z,
190                            kernel_name, llvm_module);
191   }
192 }
193 
194 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
195                                    int64_t v) {
196   mlir::APInt value(sizeof(int64_t) * 8, v, /*isSigned=*/true);
197   return std::binary_search(
198       elements.begin(), elements.end(), value,
199       [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
200 }
201 
202 bool MhloOpIsElementwise(mlir::Operation* op) {
203   CHECK(op->getDialect() ==
204         op->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>());
205   auto opcode = *MhloToHloOpcode(op);
206   if (HloInstruction::IsOpElementwise(opcode)) {
207     return true;
208   }
209   if (opcode == HloOpcode::kMap) {
210     int iota = 0;
211     for (const llvm::APInt& i :
212          mlir::cast<mlir::mhlo::MapOp>(op).dimensions()) {
213       if (i.getZExtValue() != iota) {
214         return false;
215       }
216       iota++;
217     }
218     return true;
219   }
220   // TODO(timshen): not sure about whether porting
221   // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
222   // doesn't use such information.
223   return false;
224 }
225 
226 bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) {
227   int instruction_count = 0;
228   for (mlir::Operation& instr : fusion.getRegion().front()) {
229     if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
230                   mlir::bufferization::ToTensorOp, mlir::memref::TensorStoreOp>(
231             &instr)) {
232       continue;
233     }
234     instruction_count++;
235   }
236   return instruction_count == 1;
237 }
238 
239 bool MayPreventVectorization(mlir::Operation* op) {
240   // An empirically chosen constant: unrolling concat with a large amount of
241   // arguments causes excessive register spilling.
242   static constexpr int kMaxConcatArgumentsForUnrolling = 10;
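  // For example, a fusion containing a concatenate with 12 operands reports
  // true (no unrolling), as do ops such as sin, cos, power, sort, or a
  // reduce with more than one result.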
243 
244   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
245 
246   for (mlir::Operation& instr : fusion.getRegion().front()) {
247     if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
248                   mlir::bufferization::ToTensorOp, mlir::memref::TensorStoreOp>(
249             &instr)) {
250       continue;
251     }
252 
253     CHECK(instr.getDialect() ==
254           instr.getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>())
255         << MlirToString(op);
256     switch (*MhloToHloOpcode(&instr)) {
257       case HloOpcode::kReduceWindow:
258       case HloOpcode::kSort:
259       case HloOpcode::kDot:
260       case HloOpcode::kSin:
261       case HloOpcode::kCos:
262       case HloOpcode::kPower:
263       case HloOpcode::kAtan2:
264         return true;
265       case HloOpcode::kConcatenate:
266         if (instr.getOperands().size() > kMaxConcatArgumentsForUnrolling) {
267           return true;
268         }
269         break;
270       case HloOpcode::kReduce:
271         if (instr.getNumResults() > 1) {
272           return true;
273         }
274         break;
275       default:
276         break;
277     }
278   }
279   return false;
280 }
281 
282 // Computes the maximum valid unroll factor for a given instruction.
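// For example, if xla_gpu_max_kernel_unroll_factor is 4, a shape with 6
// elements is unrolled by 2 (4 does not divide 6 evenly) and a shape with 7
// elements is not unrolled at all.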
283 int ComputeMaxUnrollFactor(mlir::Type type,
284                            const HloModuleConfig& hlo_module_config) {
285   int max_unroll_factor =
286       hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
287 
288   // Find the largest possible power of two to unroll by.
289   // TODO(kramerb): Make this smarter.
290 
291   auto shaped_type = type.cast<mlir::ShapedType>();
292   int64_t num_elements = std::accumulate(
293       shaped_type.getShape().begin(), shaped_type.getShape().end(), int64_t{1},
294       std::multiplies<int64_t>());
295   for (int i = max_unroll_factor; i > 1; i /= 2) {
296     if (num_elements % i == 0) {
297       return i;
298     }
299   }
300 
301   // Cannot unroll.
302   return 1;
303 }
304 
305 // Computes the maximum valid unroll factor for a given instruction.
306 int ComputeMaxUnrollFactor(mlir::Operation* op,
307                            const HloModuleConfig& hlo_module_config) {
308   mlir::Type element_shape = [&] {
309     if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
310       return fusion.getFusionRoots()[0]->getResult(0).getType();
311     }
312     return GetHloOutputs(op)[0].getType();
313   }();
314   return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
315 }
316 
317 // Returns the llvm type for the indices used in the kernel that contains the
318 // hlo instruction. Such indices include the index for the parallel loop and
319 // the indices for the tensors accessed by the kernel. The return type is i32
320 // iff the following conditions are met:
321 //  . The launch_size of the kernel is within the range of i32.
322 //  . The sizes of all the tensors accessed within the kernel are within the
323 //    range of i32.
324 // Otherwise, the return type is i64.
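// For example, a kernel whose launch size and tensor element counts all stay
// below 2^31 is indexed with i32, which is cheaper on the GPU; a single
// operand with more than INT32_MAX elements forces every index to i64.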
325 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo,
326                                   int64_t launch_size, llvm::IRBuilder<>* b) {
327   // Find the unnested hlo instruction for which the kernel is generated.
328   const HloInstruction* unnested_hlo = hlo;
329   const HloComputation* computation = hlo->parent();
330   if (computation->IsFusionComputation()) {
331     unnested_hlo = computation->FusionInstruction();
332   }
333 
334   auto shape_in_range = [&](const Shape& s) {
335     bool in_range = true;
336     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
337                                       const ShapeIndex& /*index*/) {
338       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
339         in_range = false;
340       }
341     });
342 
343     return in_range;
344   };
345 
346   llvm::Type* i64_ty = b->getInt64Ty();
347   // Check launch dimension
348   if (!IsInt32(launch_size)) {
349     return i64_ty;
350   }
351 
352   // Check the size of result tensors
353   if (!shape_in_range(unnested_hlo->shape())) {
354     return i64_ty;
355   }
356 
357   auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
358     return shape_in_range(operand->shape());
359   };
360 
361   // Check the size of input tensors
362   if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
363     return i64_ty;
364   }
365 
366   // Check the size of the internal result tensors
367   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
368     if (!absl::c_all_of(
369             unnested_hlo->fused_instructions_computation()->instructions(),
370             hlo_shape_in_range)) {
371       return i64_ty;
372     }
373   }
374 
375   return b->getInt32Ty();
376 }
377 
378 // The same as GetIndexTypeForKernel, but works with MLIR ops.
379 llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size,
380                                   llvm::IRBuilder<>* b) {
381   auto shape_in_range = [&](const Shape& s) {
382     bool in_range = true;
383     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
384                                       const ShapeIndex& /*index*/) {
385       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
386         in_range = false;
387       }
388     });
389 
390     return in_range;
391   };
392 
393   llvm::Type* i64_ty = b->getInt64Ty();
394   // Check launch dimension
395   if (!IsInt32(launch_size)) {
396     return i64_ty;
397   }
398 
399   // Check the size of result tensors
400   for (auto result : GetHloOutputs(op)) {
401     if (!shape_in_range(GetShape(result))) {
402       return i64_ty;
403     }
404   }
405 
406   auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
407     return shape_in_range(GetShape(operand));
408   };
409 
410   // Check the size of input tensors
411   if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
412     return i64_ty;
413   }
414 
415   // Check the size of the internal result tensors
416   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
417     auto result = fusion.getRegion().walk([&](mlir::Operation* op) {
418       for (mlir::Value result : op->getResults()) {
419         if (!hlo_shape_in_range(result)) {
420           return mlir::WalkResult::interrupt();
421         }
422       }
423       return mlir::WalkResult::advance();
424     });
425     if (result.wasInterrupted()) {
426       return i64_ty;
427     }
428   }
429 
430   return b->getInt32Ty();
431 }
432 
433 // Gets the input shape of the ROOT slices, which will be used as the kernel
434 // launch dims. The slice input fusion requires the input shapes of the ROOT
435 // slices to be the same although the (slice) output shapes can be different.
436 //
437 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
438 // slices are the same and the slices are non-strided. Otherwise, returns
439 // FailedPrecondition.
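// For example (illustrative), a fusion whose root is
//   tuple(slice(f32[1024] p0), slice(f32[1024] p1))
// returns f32[1024]; if the two slices read operands of different shapes, a
// FailedPrecondition error is returned instead.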
440 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
441     const HloComputation* fused_computation) {
442   const HloInstruction& root = *fused_computation->root_instruction();
443   if (root.opcode() == HloOpcode::kSlice) {
444     return root.operands()[0]->shape();
445   }
446 
447   CHECK_EQ(root.opcode(), HloOpcode::kTuple);
448   const Shape& first_slice_operand_shape =
449       root.operands()[0]->operands()[0]->shape();
450   for (size_t i = 1; i < root.operands().size(); ++i) {
451     const HloInstruction* slice = root.operands()[i];
452     const Shape& operand_shape = slice->operands()[0]->shape();
453     if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
454                                              operand_shape)) {
455       return FailedPrecondition(
456           "Fused slices do not have the same input shape, fused computation = "
457           "%s.",
458           root.parent()->name());
459     }
460   }
461 
462   return first_slice_operand_shape;
463 }
464 
465 // Returns a sanitized (doesn't need quoting) identifier name from a location.
466 std::string GetIrNameFromLoc(mlir::Location loc) {
467   return llvm_ir::SanitizeConstantName(mlir::GetNameFromLoc(loc));
468 }
469 
470 // For a row reduction, returns the number of rows we can process in parallel
471 // per warp.
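// For example, with a warp size of 32, a reduced dimension of size 8 gives 4
// rows per warp, while sizes that do not evenly divide the warp size (e.g. 24)
// or that are >= the warp size fall back to 1 row per warp.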
472 int RowReductionGetRowsPerWarp(int reduced_dimension_size) {
473   if (WarpSize() % reduced_dimension_size != 0 ||
474       reduced_dimension_size >= WarpSize()) {
475     return 1;
476   }
477   return WarpSize() / reduced_dimension_size;
478 }
479 
480 }  // namespace
481 
482 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
483                                      IrEmitterContext* ir_emitter_context)
484     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
485 
486 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
487     const HloModuleConfig& hlo_module_config,
488     IrEmitterContext* ir_emitter_context) {
489   return std::unique_ptr<IrEmitterUnnested>(
490       new IrEmitterUnnested(hlo_module_config, ir_emitter_context));
491 }
492 
493 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
494     absl::string_view name, absl::Span<const BufferAllocation* const> args) {
495   // Compute the kernel name. The opcode string may contain "-" which cannot be
496   // in a PTX function name, so sanitize the name before uniquifying it.
497   std::string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
498       llvm_ir::SanitizeFunctionName(std::string(name)));
499 
500   // Create the kernel and add it to the module.
501   llvm::LLVMContext& context = module_->getContext();
502   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
503       /*Result=*/llvm::Type::getVoidTy(context),
504       std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
505       /*isVarArg=*/false);
506   llvm::Function* kernel =
507       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
508                              kernel_name.c_str(), module_);
509 
510   // Add dereferenceable and alignment information to each of the kernel's
511   // parameters.
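  // For example (illustrative), each argument ends up carrying attributes
  // along the lines of `dereferenceable(<allocation size>) align <alignment>`,
  // where the alignment depends on whether the buffer is an entry parameter,
  // a constant, or an XLA-allocated temp buffer.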
512   auto arg_it = kernel->arg_begin();
513   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
514     const BufferAllocation* alloc = args[arg_no];
515     llvm::Argument& fn_arg = *arg_it;
516     ++arg_it;
517 
518     kernel->addDereferenceableParamAttr(arg_no, alloc->size());
519 
520     const int64_t alignment = [&] {
521       if (alloc->is_entry_computation_parameter()) {
522         return kEntryParameterAlignBytes;
523       } else if (alloc->is_constant()) {
524         return kConstantBufferAlignBytes;
525       } else {
526         return kXlaAllocatedBufferAlignBytes;
527       }
528     }();
529 
530     kernel->addParamAttr(
531         arg_no,
532         llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
533 
534     if (alloc->IsPreallocatedTempBuffer()) {
535       fn_arg.setName("temp_buf");
536     } else {
537       fn_arg.setName(StrCat("alloc", alloc->index()));
538     }
539   }
540 
541   AnnotateFunctionAsGpuKernel(module_, kernel, &b_);
542 
543   // TODO(b/65380986): Investigate if adding fast math flags for generated
544   // kernels makes sense.
545 
546   // Update the insert point to the entry basic block.
547   llvm::BasicBlock* entry_bb =
548       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
549 
550   // Emit a "return void" at entry_bb's end, and set the insert point before
551   // that return instruction.
552   b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
553 
554   return kernel;
555 }
556 
557 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice(
558     mlir::Value v, std::string* constant_name) {
559   return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(),
560                                       constant_name);
561 }
562 
563 Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) {
564   auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(op);
565   auto module = get_global->getParentOfType<mlir::ModuleOp>();
566   auto global = mlir::cast<mlir::memref::GlobalOp>(
567       module.lookupSymbol(get_global.getName()));
568 
569   auto literal = global.getInitialValue()->dyn_cast<mlir::DenseElementsAttr>();
570   TF_RET_CHECK(literal);
571 
572   const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
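  // Constants with more than one element are not materialized in the IR:
  // their bytes are handed to GpuExecutable via info.content, and the global
  // below is emitted as a zeroinitializer placeholder of the same size.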
573 
574   TF_ASSIGN_OR_RETURN(int element_bytes,
575                       GetElementTypeBytes(literal.getType().getElementType()));
576   llvm::ArrayType* global_type = llvm::ArrayType::get(
577       b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
578 
579   GpuExecutable::ConstantInfo info;
580   llvm::Constant* initializer;
581   if (should_emit_initializer) {
582     std::vector<uint8_t> content;
583     TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
584     initializer =
585         llvm::ConstantDataArray::get<uint8_t>(module_->getContext(), content);
586   } else {
587     TF_RETURN_IF_ERROR(
588         CopyDenseElementsDataToXlaFormat(literal, &info.content));
589     initializer = llvm::ConstantAggregateZero::get(global_type);
590   }
591 
592   // These globals will be looked up by name by GpuExecutable so we need to
593   // give them an external linkage.  Not all of their uses are visible in
594   // the LLVM IR, so we can't give them a linkage that merely preserves their
595   // names (like available_externally), we also need to ensure that they stick
596   // around even if they're "unused".
597   //
598   // We may have to be more clever here in the future if we notice that we're
599   // keeping around too many globals because of their linkage.
600   llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
601       global_type, /*isConstant=*/should_emit_initializer,
602       llvm::GlobalValue::ExternalLinkage,
603       /*Initializer=*/initializer, global.getSymName(),
604       /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
605       /*AddressSpace=*/0,
606       /*isExternallyInitialized=*/false);
607   global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
608   module_->getGlobalList().push_back(global_for_const);
609 
610   info.symbol_name.assign(global.getSymName().begin(),
611                           global.getSymName().end());
612 
613   info.allocation_index =
614       global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
615   ir_emitter_context_->constants().push_back(std::move(info));
616   return OkStatus();
617 }
618 
619 static ConditionalThunkConfig GetConditionalThunkConfig(
620     mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) {
621   ConditionalThunkConfig config;
622   config.branch_index_is_bool = op.getIndex()
623                                     .getType()
624                                     .cast<mlir::ShapedType>()
625                                     .getElementType()
626                                     .isInteger(
627                                         /*width=*/1);
628   config.branch_count = op.getBranches().size();
629   // Pass nullptr as the HloInstruction* to the branch_thunks
630   // constructors because these SequentialThunks are logically "part of"
631   // this ConditionalThunk, and shouldn't be profiled separately from it.
632   config.branch_thunks.reserve(branch_thunk_sequences.size());
633   for (auto& branch_thunk_sequence : branch_thunk_sequences) {
634     config.branch_thunks.emplace_back(new SequentialThunk(
635         Thunk::ThunkInfo(), std::move(branch_thunk_sequence)));
636   }
637   return config;
638 }
639 
640 Status IrEmitterUnnested::EmitConditional(mlir::Operation* op) {
641   auto conditional = mlir::cast<mlir::lmhlo::CaseOp>(op);
642 
643   std::vector<ThunkSequence> branch_thunks;
644 
645   int branch_count = conditional.getBranches().size();
646   branch_thunks.reserve(branch_count);
647 
648   for (int j = 0; j < branch_count; ++j) {
649     mlir::Region* branch_computation = &conditional.getBranches()[j];
650     TF_ASSIGN_OR_RETURN(
651         auto ir_emitter,
652         IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
653     TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(branch_computation));
654     branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
655   }
656 
657   ConditionalThunkConfig config =
658       GetConditionalThunkConfig(conditional, std::move(branch_thunks));
659 
660   TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.getIndex()));
661   AddThunkToThunkSequence(std::unique_ptr<Thunk>(
662       new ConditionalThunk(GetThunkInfo(op), std::move(config), slice)));
663   return OkStatus();
664 }
665 
666 llvm::Value* IrEmitterUnnested::CreateLoad(llvm::Value* address,
667                                            llvm::Type* data_type,
668                                            int alignment_bytes) {
669   int data_bytes = data_type->getPrimitiveSizeInBits() /
670                    primitive_util::BitWidth(PrimitiveType::U8);
671   if (alignment_bytes == 0) {
672     return b_.CreateLoad(data_type,
673                          b_.CreateBitCast(address, data_type->getPointerTo()));
674   }
675 
676   int alignment_bitwidth =
677       alignment_bytes * primitive_util::BitWidth(PrimitiveType::U8);
678 
679   llvm::Value* output = llvm::ConstantInt::get(data_type, 0);
680   for (int offset_bytes = 0; offset_bytes < data_bytes;
681        offset_bytes += alignment_bytes) {
682     llvm::Value* offset_address = b_.CreateConstInBoundsGEP1_32(
683         b_.getInt8Ty(), address, offset_bytes, "offset_address");
684     llvm::Value* partial_value = b_.CreateLoad(b_.getIntNTy(alignment_bitwidth),
685                                                offset_address, "partial_value");
686     llvm::Value* zextd =
687         b_.CreateZExt(partial_value, output->getType(), "partial_value_zextd");
688     llvm::Value* shifted = b_.CreateShl(
689         zextd, llvm::ConstantInt::get(b_.getInt32Ty(), offset_bytes),
690         "partial_input_shifted");
691     output = b_.CreateAdd(output, shifted, "output_updated");
692   }
693   return output;
694 }
695 
696 void IrEmitterUnnested::CreateStore(llvm::Value* data, llvm::Value* address,
697                                     int alignment_bytes) {
698   int data_bytes = data->getType()->getPrimitiveSizeInBits() /
699                    primitive_util::BitWidth(PrimitiveType::U8);
700   CHECK_GE(data_bytes, alignment_bytes);
701   if (alignment_bytes == 0) {
702     b_.CreateStore(data,
703                    b_.CreateBitCast(address, data->getType()->getPointerTo()));
704     return;
705   }
706 
707   int alignment_bitwidth =
708       alignment_bytes * primitive_util::BitWidth(PrimitiveType::U8);
709 
710   for (int offset_bytes = 0; offset_bytes < data_bytes;
711        offset_bytes += alignment_bytes) {
712     llvm::Value* offset_address = b_.CreateConstInBoundsGEP1_32(
713         b_.getInt8Ty(), address, offset_bytes, "offset_address");
714     llvm::Value* shifted_partial = b_.CreateTrunc(
715         b_.CreateLShr(data,
716                       llvm::ConstantInt::get(b_.getInt32Ty(), offset_bytes)),
717         b_.getIntNTy(alignment_bitwidth), "truncated_value");
718     b_.CreateStore(
719         shifted_partial,
720         b_.CreateBitCast(offset_address,
721                          b_.getIntNTy(alignment_bitwidth)->getPointerTo()));
722   }
723 }
724 
725 // Input = {dynamic array (with dynamic dimension metadata at the end)}
726 // Output = {static array, dynamic_dim0, dynamic_dim1}
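// Illustrative example: a f32[<=10, <=20] operand arrives as a buffer holding
// the 10x20 static payload followed by two i32 metadata words with the actual
// sizes; the kernel copies the valid elements into the static-shaped output
// and writes the two sizes into output[1] and output[2].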
727 Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) {
728   // TODO(jurahul): Create an op to represent PadToStatic.
729   auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
730   int unroll_factor = 1;
731   std::string ir_name = GetIrNameFromLoc(pad_to_static.getLoc());
732 
733   const Shape& input_shape = GetShape(pad_to_static.getArgs().front());
734   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
735                       CalculateLaunchDimensions(
736                           input_shape, ir_emitter_context_->gpu_device_info(),
737                           {unroll_factor}));
738   std::vector<llvm_ir::IrArray> ir_arrays;
739   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
740                       BuildKernelThunk(pad_to_static, GetThunkInfo(op),
741                                        &ir_arrays, launch_dimensions));
742 
743   const llvm_ir::IrArray source_array = ir_arrays[0];
744   const llvm_ir::IrArray output_array = ir_arrays[1];
745   auto output_dim_arrays =
746       absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
747 
748   llvm::Type* index_ty = GetIndexTypeForKernel(
749       pad_to_static, launch_dimensions.launch_bound(), &b_);
750 
751   // pseudo code for PadToStatic on a 2d array
752   //   int* source_array = input[0];
753   //   int* dest_array = output[0];
754   llvm::Value* source_buffer = source_array.GetBasePointer();
755   llvm::Value* raw_buffer =
756       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
757 
758   // TODO(jurahul): input_shape here is the static shape of the input (which has
759   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
760   // memref. When we change that to a more appropriate representation in MLIR,
761   // fix this code to correctly deduce the static shape backing the dynamically
762   // shaped memref.
763   int64_t raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
764 
765   //   int* dyn_dim0_size = source_array + meta_data_offset;
766   //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
767   std::vector<llvm::Value*> dynamic_dims;
768   int alignment = raw_data_size % sizeof(int32_t);
769   for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) {
770     // Dynamic size of each dimension is attached at the end of the source
771     // array (operand(0)). We need to extract these values.
772     const Shape& dim_shape = GetShape(pad_to_static.getOutput()[i]);
773     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
774 
775     const int64_t dim_index = i - 1;
776     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
777         b_.getInt8Ty(), raw_buffer,
778         raw_data_size + dim_index * sizeof(int32_t));
779     llvm::Value* dyn_dim_size =
780         CreateLoad(metadata, b_.getInt32Ty(), alignment);
781     dynamic_dims.push_back(dyn_dim_size);
782   }
783 
784   // Only one thread needs to store the dynamic dimension sizes.
785   //   int thread_id = GetThreadId();
786   //   int block_id = GetBlockId();
787   //   if (thread_id == 0 && block_id == 0) {
788   //     *output[1] = *dyn_dim0_size;
789   //     *output[2] = *dyn_dim1_size;
790   //   }
791   KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
792     for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) {
793       const int64_t dim_index = i - 1;
794       llvm::Value* dest_dim_size_address =
795           output_dim_arrays[dim_index].GetBasePointer();
796       // output[i] stores dynamic_dim_(i-1)
797       CreateStore(dynamic_dims[dim_index], dest_dim_size_address, alignment);
798     }
799   });
800 
801   //     int dyn_element_total = 1;
802   //     dyn_element_total *= *dyn_dim0_size;
803   //     dyn_element_total *= *dyn_dim1_size;
804   llvm::Value* dyn_element_total = llvm::ConstantInt::get(index_ty, 1);
805   for (llvm::Value* dynamic_dim : dynamic_dims) {
806     dyn_element_total =
807         b_.CreateMul(dyn_element_total,
808                      b_.CreateIntCast(dynamic_dim, dyn_element_total->getType(),
809                                       /*isSigned=*/true),
810                      /*Name=*/"dyn_element_total_pad");
811   }
812 
813   //   linear_index = block_id * threads_per_block + thread_id;
814   //   if (linear_index < max_num_element) {
815   //     Index static_index =
816   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
817   //     if (linear_index < dyn_element_total) {
818   //       Index dyn_index =
819   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
820   //       dest_array[dyn_index.dim0][dyn_index.dim1] =
821   //           source_array[static_index.dim0][static_index.dim1];
822   //     }
823   //   }
824   llvm_ir::BodyEmitter body_generator =
825       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
826     llvm::Value* linearIndex =
827         array_index.Linearize(input_shape.dimensions(), &b_);
828     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
829         b_.CreateICmpULT(linearIndex, dyn_element_total),
830         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
831     // Set IR builder insertion point to the body of the if structure.
832     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
833     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
834                                       absl::MakeSpan(dynamic_dims), &b_);
835     output_array.EmitWriteArrayElement(
836         dyn_index,
837         source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
838         /*use_linear_index=*/false);
839     return OkStatus();
840   };
841 
842   const Shape& data_shape = GetShape(pad_to_static.getOutput().front());
843   TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape,
844                                          launch_dimensions, &b_,
845                                          {unroll_factor})
846                          .EmitLoop(ir_name, index_ty));
847   thunk_sequence_.emplace_back(std::move(kernel_thunk));
848   return OkStatus();
849 }
850 
851 // Input = {static array, dynamic_dim0, dynamic_dim1}
852 // Output = {dynamic array (with dynamic dimension metadata at the end)}
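// Illustrative example: given a f32[10,20] source plus two i32 size operands,
// the kernel writes the valid elements into the destination buffer and
// appends the two sizes as metadata after the 10x20 static payload.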
853 Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) {
854   // TODO(jurahul): Create an op to represent SliceToDynamic.
855   auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
856   int unroll_factor = 1;
857   std::string ir_name = GetIrNameFromLoc(slice_to_dynamic.getLoc());
858 
859   const Shape& input_shape = GetShape(slice_to_dynamic.getArgs().front());
860   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
861                       CalculateLaunchDimensions(
862                           input_shape, ir_emitter_context_->gpu_device_info(),
863                           {unroll_factor}));
864   llvm::Type* index_ty = GetIndexTypeForKernel(
865       slice_to_dynamic, launch_dimensions.launch_bound(), &b_);
866   std::vector<llvm_ir::IrArray> ir_arrays;
867   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
868                       BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op),
869                                        &ir_arrays, launch_dimensions));
870 
871   TF_RET_CHECK(slice_to_dynamic.getOutput().size() == 1);
872   const Shape& data_shape = GetShape(slice_to_dynamic.getOutput().front());
873 
874   // TODO(jurahul): data_shape here is the static shape of the output (which has
875   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
876   // memref. When we change that to a more appropriate representation in MLIR,
877   // fix this code to correctly deduce the static shape backing the dynamically
878   // shaped memref.
879 
880   // calculate the location where metadata needs to be inserted
881   //   int* dyn_dim0_size = dest_array + meta_data_offset;
882   //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
883   int32_t raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
884 
885   // pseudo code for sliceToDynamic on a 2d array
886   //   int* source_array = input[0];
887   //   int* dest_array = output[0];
888   const llvm_ir::IrArray data_array = ir_arrays.back();
889   llvm::Value* dest_buffer = data_array.GetBasePointer();
890   llvm::Value* raw_buffer =
891       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
892 
893   // Load dynamic dimensions from memory.
894   std::vector<llvm::Value*> dynamic_dims;
895   int alignment = raw_data_size % sizeof(int32_t);
896   for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) {
897     // const int64_t dim_index = i - 1;
898     llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
899     llvm::Type* source_buffer_pointee_type = ir_arrays[i].GetBasePointeeType();
900     llvm::LoadInst* dyn_dim_size =
901         Load(source_buffer_pointee_type, source_buffer, "dyn_dim_size");
902     dynamic_dims.push_back(dyn_dim_size);
903   }
904 
905   // Only one thread needs to store the dynamic dimension sizes.
906   //   int thread_id = GetThreadId();
907   //   int block_id = GetBlockId();
908   //   if (thread_id == 0 && block_id == 0) {
909   //     *dyn_dim0_size = *input[1];
910   //     *dyn_dim1_size = *input[2];
911   //   }
912   KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
913     for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) {
914       const int64_t dim_index = i - 1;
915       llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
916           b_.getInt8Ty(), raw_buffer,
917           raw_data_size + dim_index * sizeof(int32_t));
918       // args[i] holds dynamic_dim_(i-1), which is written into the metadata.
919       CreateStore(dynamic_dims[dim_index], metadata, alignment);
920     }
921   });
922 
923   //     int dyn_element_total = 1;
924   //     dyn_element_total *= dyn_dim0_size;
925   //     dyn_element_total *= dyn_dim1_size;
926   llvm::Value* dyn_element_total = llvm::ConstantInt::get(index_ty, 1);
927   for (llvm::Value* dynamic_dim : dynamic_dims) {
928     dyn_element_total =
929         b_.CreateMul(dyn_element_total,
930                      b_.CreateIntCast(dynamic_dim, dyn_element_total->getType(),
931                                       /*isSigned=*/true),
932                      /*Name=*/"dyn_element_total_slice");
933   }
934 
935   //   linear_index = block_id * threads_per_block + thread_id;
936   //   if (linear_index < max_num_element) {
937   //     Index static_index =
938   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
939   //     if (linear_index < dyn_element_total) {
940   //       Index dyn_index =
941   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
942   //       dest_array[static_index.dim0][static_index.dim1] =
943   //           source_array[dyn_index.dim0][dyn_index.dim1];
944   //     }
945   //   }
946   llvm_ir::BodyEmitter body_generator =
947       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
948     llvm::Value* linearIndex =
949         array_index.Linearize(input_shape.dimensions(), &b_);
950     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
951         b_.CreateICmpULT(linearIndex, dyn_element_total),
952         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
953     // Set IR builder insertion point to the body of the if structure.
954     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
955     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
956                                       absl::MakeSpan(dynamic_dims), &b_);
957 
958     data_array.EmitWriteArrayElement(
959         array_index,
960         ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
961                                           /*use_linear_index=*/false),
962         &b_);
963     return OkStatus();
964   };
965 
966   TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape,
967                                          launch_dimensions, &b_,
968                                          {unroll_factor})
969                          .EmitLoop(ir_name, index_ty));
970   thunk_sequence_.emplace_back(std::move(kernel_thunk));
971   return OkStatus();
972 }
973 
974 Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) {
975   using mlir::dyn_cast;
976   using mlir::lmhlo_gpu::Activation;
977   using mlir::lmhlo_gpu::ConvBackwardFilterOp;
978   using mlir::lmhlo_gpu::ConvBackwardInputOp;
979   using mlir::lmhlo_gpu::ConvForwardFusedOp;
980   using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
981   using mlir::lmhlo_gpu::ConvForwardOp;
982 
983   // Last 2 operands of the convolution operation are the result and scratch.
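  // For example (illustrative), a ConvForwardOp with operands
  // (input, filter, output, scratch) contributes {input, filter} to
  // operand_slices; output and scratch are bound to conv_result_slice and
  // scratch_slice below.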
984   std::vector<BufferAllocation::Slice> operand_slices;
985   int64_t num_operands = op->getNumOperands();
986   operand_slices.reserve(num_operands - 2);
987   for (mlir::Value operand : op->getOperands().drop_back(2)) {
988     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
989     operand_slices.push_back(slice);
990   }
991 
992   mlir::Value conv_result = op->getOperand(num_operands - 2);
993   mlir::Value scratch_result = op->getOperand(num_operands - 1);
994   TF_ASSIGN_OR_RETURN(auto conv_result_slice, GetAllocationSlice(conv_result));
995   TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result));
996 
997   auto apply_layout = [](const Shape& shape,
998                          mlir::ArrayRef<int64_t> minor_to_major) {
999     return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
1000                                           shape.dimensions(), minor_to_major);
1001   };
1002 
1003   GpuConvDescriptor descriptor;
1004 
1005   auto fill_conv_descriptor = [&](auto op) {
1006     descriptor.operand0_shape =
1007         apply_layout(GetShape(op->getOperand(0)),
1008                      op.getBackendConfig().getOperand_0Layout());
1009     descriptor.operand1_shape =
1010         apply_layout(GetShape(op->getOperand(1)),
1011                      op.getBackendConfig().getOperand_1Layout());
1012     descriptor.result_shape = apply_layout(
1013         GetShape(conv_result), op.getBackendConfig().getResultLayout());
1014     descriptor.dnums = ConvertConvDimensionNumbers(op.getDimensionNumbers());
1015     descriptor.scratch_size = scratch_slice.size();
1016     mlir::DenseIntElementsAttr window_strides =
1017         op.getWindowStrides().getValue();
1018     mlir::DenseIntElementsAttr padding = op.getPadding().getValue();
1019     mlir::DenseIntElementsAttr lhs_dilation = op.getLhsDilation().getValue();
1020     mlir::DenseIntElementsAttr rhs_dilation = op.getRhsDilation().getValue();
1021     mlir::DenseElementsAttr window_reversal = op.getWindowReversal().getValue();
1022     for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
1023       WindowDimension* dim = descriptor.window.add_dimensions();
1024       // Window size for a convolution is the same as the kernel size.
1025       // The kernel size of the convolution is operand1_shape. We need to look
1026       // at the kernel spatial dimensions of the convolution dimension numbers
1027       // to get the window size.
1028       int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
1029       dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
1030       dim->set_stride(window_strides.getValues<int64_t>()[index]);
1031       dim->set_padding_low(padding.getValues<int64_t>()[index]);
1032       dim->set_padding_high(padding.getValues<int64_t>()[index]);
1033       dim->set_base_dilation(lhs_dilation.getValues<int64_t>()[index]);
1034       dim->set_window_dilation(rhs_dilation.getValues<int64_t>()[index]);
1035       dim->set_window_reversal(window_reversal.getValues<bool>()[index]);
1036     }
1037     descriptor.feature_group_count = op.getFeatureGroupCount();
1038     {
1039       auto* algorithm = descriptor.backend_config.mutable_algorithm();
1040       algorithm->set_algo_id(op.getBackendConfig().getAlgorithm());
1041       algorithm->set_math_type(op.getBackendConfig().getTensorOpsEnabled()
1042                                    ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
1043                                    : se::dnn::AlgorithmProto::DEFAULT_MATH);
1044       for (int i = 0; i < op.getBackendConfig().getKnobIds().size(); ++i) {
1045         // N.B. tuning_knobs is a map rather than a repeated field, so this
1046         // doesn't require reserving space up front.
1047         (*algorithm
1048               ->mutable_tuning_knobs())[op.getBackendConfig().getKnobIds()[i]] =
1049             op.getBackendConfig().getKnobValues()[i];
1050       }
1051       algorithm->set_is_cudnn_frontend(
1052           op.getBackendConfig().getIsCudnnFrontend());
1053       auto workspace_size = op.getBackendConfig().getWorkspaceSize();
1054       if (workspace_size >= 0) {
1055         algorithm->mutable_workspace_size()->set_value(workspace_size);
1056       }
1057     }
1058     descriptor.backend_config.set_conv_result_scale(
1059         op.getResultScale().convertToDouble());
1060   };
1061 
1062   auto set_activation_mode = [&](auto op) -> Status {
1063     TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
1064                         ConvertConvActivationMode(op.getActivationMode()));
1065     descriptor.backend_config.set_activation_mode(
1066         static_cast<int64_t>(activation_mode));
1067     return OkStatus();
1068   };
1069 
1070   if (auto conv = dyn_cast<ConvForwardOp>(op)) {
1071     descriptor.kind = CudnnConvKind::kForward;
1072     fill_conv_descriptor(conv);
1073   } else if (auto conv = dyn_cast<ConvBackwardInputOp>(op)) {
1074     descriptor.kind = CudnnConvKind::kBackwardInput;
1075     fill_conv_descriptor(conv);
1076   } else if (auto conv = dyn_cast<ConvBackwardFilterOp>(op)) {
1077     descriptor.kind = CudnnConvKind::kBackwardFilter;
1078     fill_conv_descriptor(conv);
1079   } else if (auto conv = dyn_cast<ConvForwardFusedOp>(op)) {
1080     descriptor.kind = CudnnConvKind::kForwardActivation;
1081     fill_conv_descriptor(conv);
1082     TF_RETURN_IF_ERROR(set_activation_mode(conv));
1083   } else if (auto conv = dyn_cast<ConvForwardFusedSideInputOp>(op)) {
1084     descriptor.kind = CudnnConvKind::kForwardActivation;
1085     fill_conv_descriptor(conv);
1086     TF_RETURN_IF_ERROR(set_activation_mode(conv));
1087     descriptor.backend_config.set_side_input_scale(
1088         conv.getSideInputScale().convertToDouble());
1089   } else {
1090     return InternalError("Unexpected operation");
1091   }
1092   TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
1093   AddThunkToThunkSequence(std::make_unique<ConvolutionThunk>(
1094       GetThunkInfo(op), std::move(config), std::move(operand_slices),
1095       conv_result_slice, scratch_slice));
1096   return OkStatus();
1097 }
1098 
1099 Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) {
1100   auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op);
1101   TF_RET_CHECK(gemm != nullptr);
1102 
1103   TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA()));
1104   TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB()));
1105   TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC()));
1106 
1107   TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
1108   auto thunk =
1109       std::make_unique<GemmThunk>(GetThunkInfo(op), std::move(config), a, b, c);
1110 
1111   AddThunkToThunkSequence(std::move(thunk));
1112   return OkStatus();
1113 }
1114 
1115 #if GOOGLE_CUDA
1116 
1117 Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) {
1118   auto matmul = mlir::dyn_cast<mlir::lmhlo_gpu::CublasLtMatmulOp>(op);
1119   TF_RET_CHECK(matmul != nullptr);
1120 
1121   TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(matmul.getA()));
1122   TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(matmul.getB()));
1123   TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(matmul.getC()));
1124   TF_ASSIGN_OR_RETURN(auto d, GetAllocationSlice(matmul.getD()));
1125 
1126   BufferAllocation::Slice bias;
1127   if (matmul.getBias() != nullptr) {
1128     TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias()));
1129   }
1130 
1131   TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan,
1132                       cublas_lt::MatmulPlan::For(matmul));
1133   auto thunk = std::make_unique<CublasLtMatmulThunk>(
1134       GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d,
1135       bias);
1136 
1137   AddThunkToThunkSequence(std::move(thunk));
1138   return OkStatus();
1139 }
1140 
1141 #endif  // GOOGLE_CUDA
1142 
1143 namespace {
1144 // An MLIR value and its name as defined in the ODS spec.
1145 struct NamedValue {
1146   mlir::Value value;
1147   absl::string_view name;
1148 };
1149 
1150 // Determines whether we enable the row-optimized codegen. When we have a
1151 // fusion with only point-wise operations, scalar broadcasting and row
1152 // broadcasting, we can trigger a kernel that vectorizes the row loads.
1153 // This speeds up the kernel, in particular on A100.
1154 // Returns a pair<bool, int>. The bool indicates whether we should try to
1155 // enable row vectorization. The int is the number of inputs with the
1156 // highest rank.
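// Illustrative example: a fusion computing add(p0, broadcast(p1)), where p1 is
// broadcast along the last (row) dimension of a row-major p0, yields
// {true, 1}; any unsupported op in the body (e.g. a transpose) yields
// {false, 0}.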
1157 std::pair<bool, int> RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion) {
1158   const auto is_row_major = [](mlir::Value value) {
1159     // Only tested when the inputs are row-major, so only
1160     // enable that case. It might also work if only the
1161     // innermost dimension is contiguous.
1162     return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
1163   };
1164   bool row_vectorized =
1165       fusion.getFusionResults().size() == 1 &&  // Not tested with MOF.
1166       absl::c_all_of(GetHloOperands(fusion), is_row_major) &&
1167       absl::c_all_of(GetHloOutputs(fusion), is_row_major);
1168 
1169   // Check that the operations in the fusion are supported.  Each
1170   // supported operation (or category) must be manually vetted as XLA
1171   // only unrolls and relies on LLVM to vectorize. But this is brittle.
1172   // Currently tested and supported operations:
1173   // Elementwise, scalar and row broadcasting.
1174   //
1175   // We also detect at the same time if there is a row broadcasting
1176   // operation.
1177   bool some_row_broadcasting = false;
1178   auto out_rank =
1179       fusion.getFusionResults()[0].getType().cast<mlir::ShapedType>().getRank();
1180   int num_big_inputs = 0;
1181   for (mlir::Operation& op : fusion.getRegion().front()) {
1182     if (auto load = mlir::dyn_cast<mlir::bufferization::ToTensorOp>(op)) {
1183       auto rank = load.getResult().getType().cast<mlir::ShapedType>().getRank();
1184       num_big_inputs += static_cast<int>(rank == out_rank);
1185       continue;
1186     } else if (mlir::isa<mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1187                          mlir::mhlo::ReturnOp, mlir::mhlo::ConstantOp,
1188                          mlir::lmhlo::ConstantOp>(op)) {
1189       continue;
1190     }
1191     HloOpcode opcode = *MhloToHloOpcode(&op);
1192     if (HloInstruction::IsOpElementwise(opcode)) {
1193       continue;
1194     }
1195 
1196     if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1197       const auto& broadcast_dimensions_size =
1198           broadcast.broadcast_dimensions().size();
1199       if (broadcast_dimensions_size == 0) {
1200         continue;
1201       }
1202       llvm::SmallVector<int64_t> broadcast_dimensions;
1203       broadcast_dimensions.reserve(broadcast_dimensions_size);
1204       for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) {
1205         broadcast_dimensions.push_back(int_value.getSExtValue());
1206       }
1207 
1208       auto rank = GetShape(broadcast.getResult()).rank();
1209       if (broadcast_dimensions.size() == 1 &&
1210           broadcast_dimensions.back() == (rank - 1)) {
1211         some_row_broadcasting = true;
1212         continue;
1213       }
1214     }
1215     VLOG(2) << "Row vectorization not enabled due to this op: "
1216             << MlirToString(&op);
1217     return std::make_pair(false, 0);
1218   }
1219   // Trigger only when there is a row broadcasting.
1220   return std::make_pair(row_vectorized && some_row_broadcasting,
1221                         num_big_inputs);
1222 }
1223 }  // namespace
1224 
1225 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1226 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
1227   auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
1228 
1229   const Shape shape = GetShape(cholesky_op.getInput());
1230   int ndim = shape.dimensions_size();
1231   CHECK_GE(ndim, 2);
1232   int64_t n = shape.dimensions(ndim - 1);
1233 
1234   const auto& dims = shape.dimensions();
1235   int64_t batch_size =
1236       std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
1237                       [](int64_t a, int64_t b) { return a * b; });
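  // Worked example (added, not in the original): a batched input of shape
  // f32[4,3,8,8] has ndim = 4, n = 8, and batch_size = 4 * 3 = 12, so the
  // factorization is applied to each of the 12 trailing 8x8 matrices.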
1238 
1239   TF_ASSIGN_OR_RETURN(auto operand_buffer,
1240                       GetAllocationSlice(cholesky_op.getInput()));
1241   TF_ASSIGN_OR_RETURN(auto a_buffer,
1242                       GetAllocationSlice(cholesky_op.getOutput()));
1243   TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1244                       GetAllocationSlice(cholesky_op.getScratch()));
1245   TF_ASSIGN_OR_RETURN(auto info_buffer,
1246                       GetAllocationSlice(cholesky_op.getInfo()));
1247 
1248   ThunkSequence thunks;
1249 
1250   if (operand_buffer != a_buffer) {
1251     thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
1252         GetThunkInfo(op),
1253         /*source_address=*/operand_buffer,
1254         /*destination_buffer=*/a_buffer,
1255         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1256   }
1257 
1258   CholeskyOptions options;
1259   options.set_lower(cholesky_op.getIsLower());
1260   thunks.push_back(std::make_unique<CholeskyThunk>(
1261       GetThunkInfo(op), options,
1262       PtxOptsFromDebugOptions(hlo_module_config_.debug_options()), a_buffer,
1263       workspace_buffer, info_buffer, shape.element_type(), batch_size, n));
1264 
1265   // Elide the sequential thunk if there's no copy.
1266   if (thunks.size() == 1) {
1267     AddThunkToThunkSequence(std::move(thunks[0]));
1268   } else {
1269     AddThunkToThunkSequence(
1270         std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
1271   }
1272 
1273   return OkStatus();
1274 }
1275 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1276 
1277 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
1278   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1279   const std::string call_target_name = custom_call.getCallTargetName().str();
1280 
1281   void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1282       call_target_name, std::string(platform_name()));
1283   if (!call_target) {
1284     return Unimplemented(
1285         "No registered implementation for custom call to \"%s\"",
1286         call_target_name);
1287   }
1288 
1289   std::vector<CustomCallThunk::OptionalSlice> operands;
1290   std::vector<CustomCallThunk::OptionalSlice> results;
1291   if (custom_call.getTargetArgMapping()) {
1292     auto values_to_slices_with_token_holes =
1293         [&](mlir::ValueRange operands,
1294             mlir::ArrayRef<int64_t> op_to_target_mapping, int64_t num_target)
1295         -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1296       std::vector<CustomCallThunk::OptionalSlice> slices(num_target);
1297       for (auto index_and_value_it :
1298            llvm::zip(op_to_target_mapping, operands)) {
1299         int64_t index = std::get<0>(index_and_value_it);
1300         mlir::Value value = std::get<1>(index_and_value_it);
1301         TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1302                             GetAllocationSlice(value));
1303         slices[index] = slice;
1304       }
1305       return slices;
1306     };
1307 
1308     mlir::lmhlo::CustomCallTargetArgMappingAttr target_mapping =
1309         *custom_call.getTargetArgMapping();
1310     TF_ASSIGN_OR_RETURN(operands, values_to_slices_with_token_holes(
1311                                       custom_call.getArgs(),
1312                                       target_mapping.getArgsToTargetArgs(),
1313                                       target_mapping.getNumArgs()));
1314     TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes(
1315                                      custom_call.getOutput(),
1316                                      target_mapping.getResultsToTargetResults(),
1317                                      target_mapping.getNumResults()));
1318   } else {
1319     auto values_to_slices = [&](mlir::ValueRange values)
1320         -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1321       std::vector<CustomCallThunk::OptionalSlice> slices;
1322       for (mlir::Value value : values) {
1323         TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1324                             GetAllocationSlice(value));
1325         slices.push_back(slice);
1326       }
1327       return slices;
1328     };
1329 
1330     TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.getArgs()));
1331     TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.getOutput()));
1332   }
1333 
1334   CustomCallThunk::CustomCallTarget custom_call_target;
1335 
1336   // For information about this calling convention, see
1337   // xla/g3doc/custom_call.md.
1338   switch (custom_call.getApiVersion()) {
1339     case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
1340       using original_call_type =
1341           void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1342                    const char* /*opaque*/, size_t /*opaque_len*/);
1343       custom_call_target = [call_target](CustomCallThunk::Stream stream,
1344                                          void** buffers, const char* opaque,
1345                                          size_t opaque_len,
1346                                          XlaCustomCallStatus*) {
1347         auto typed_call_target =
1348             reinterpret_cast<original_call_type>(call_target);
1349         typed_call_target(stream, buffers, opaque, opaque_len);
1350       };
1351       break;
1352     case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
1353     case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
1354       using status_returning_call_type =
1355           void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1356                    const char* /*opaque*/, size_t /*opaque_len*/,
1357                    XlaCustomCallStatus* /*status*/);
1358       custom_call_target =
1359           reinterpret_cast<status_returning_call_type>(call_target);
1360       break;
1361     default:
1362       return InternalError("Unknown custom-call API version enum value: %d",
1363                            custom_call.getApiVersion());
1364   }
1365 
1366   auto thunk = std::make_unique<CustomCallThunk>(
1367       GetThunkInfo(op), std::move(custom_call_target), std::move(operands),
1368       std::move(results), custom_call.getBackendConfig().str());
1369   AddThunkToThunkSequence(std::move(thunk));
1370   return OkStatus();
1371 }
1372 
1373 Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
1374   auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
1375   const Shape operand_shape = GetShape(fft_op.getOperand());
1376   const Shape output_shape = GetShape(fft_op.getOutput());
1377   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1378   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1379 
1380   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1381                       GetAllocationSlice(fft_op.getOperand()));
1382   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1383                       GetAllocationSlice(fft_op.getOutput()));
1384   TF_ASSIGN_OR_RETURN(
1385       xla::FftType fft_type,
1386       ConvertFftType(mlir::mhlo::stringifyFftType(fft_op.getFftType())));
1387   auto fft_length_values = fft_op.getFftLength().getValues<int64_t>();
1388   std::vector<int64_t> fft_length(fft_length_values.begin(),
1389                                   fft_length_values.end());
1390 
1391   AddThunkToThunkSequence(
1392       std::make_unique<FftThunk>(GetThunkInfo(op), fft_type, fft_length,
1393                                  /*input_buffer=*/arg_slice,
1394                                  /*output_buffer=*/dest_slice,
1395                                  /*input_shape=*/operand_shape,
1396                                  /*output_shape=*/output_shape));
1397   return OkStatus();
1398 }
1399 
1400 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1401 Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) {
1402   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1403 
1404   auto operands = op->getOperands();
1405   TF_RET_CHECK(operands.size() == 4);
1406 
1407   // We expect Fortran layout for everything other than the temp buffer (the
1408   // last operand).  Fortran layout is the XLA default layout with elements 0
1409   // and 1 swapped.  For example, instead of default layout {3,2,1,0} we'd have
1410   // Fortran layout {2,3,1,0}.
1411   TF_RET_CHECK(absl::c_all_of(operands.drop_back(1), [&](mlir::Value v) {
1412     const Shape& shape = GetShape(v);
1413     const Layout& layout = shape.layout();
1414     int n = layout.minor_to_major_size();
1415     if (n < 2) {
1416       return false;
1417     }
1418     // Unfortunately the HLO -> LMHLO -> HLO conversion loses layout information
1419     // if the shape has any dimensions of size 1: In that case, the new HLO
1420     // (which we see here) will have an arbitrary value for the location of the
1421     // size-1 dimension.  Just skip this assertion if the shape has any
1422     // degenerate dimensions.
1423     if (absl::c_any_of(shape.dimensions(),
1424                        [](int64_t dim) { return dim == 1; })) {
1425       return true;
1426     }
1427     return layout.minor_to_major(0) == n - 2 &&
1428            layout.minor_to_major(1) == n - 1 &&
1429            std::is_sorted(layout.minor_to_major().begin() + 2,
1430                           layout.minor_to_major().end(),
1431                           std::greater<int64_t>());
1432   }));
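  // Worked example (added, not in the original): for a rank-4 operand with
  // Fortran layout {2,3,1,0}, n = 4, so minor_to_major(0) == 2 == n - 2,
  // minor_to_major(1) == 3 == n - 1, and the remaining entries {1,0} are
  // sorted in descending order, so the check above passes.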
1433 
1434   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1435                       GetAllocationSlice(operands[0]));
1436   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1437                       GetAllocationSlice(operands[1]));
1438   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
1439                       GetAllocationSlice(operands[2]));
1440   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice temp_slice,
1441                       GetAllocationSlice(operands[3]));
1442 
1443   const Shape b_shape = GetShape(operands[1]);
1444   const PrimitiveType elem_ty = b_shape.element_type();
1445   TriangularSolveOptions backend_config;
1446   TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(
1447       custom_call.getBackendConfig().str(), &backend_config));
1448 
1449   ThunkSequence thunks;
1450 
1451   // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1452   // aren't the same buffer.
1453   if (b_slice != result_slice) {
1454     thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
1455         Thunk::ThunkInfo(),
1456         /*source_address=*/b_slice,
1457         /*destination_buffer=*/result_slice,
1458         /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1459   }
1460 
1461   int64_t m = b_shape.dimensions(b_shape.rank() - 2);
1462   int64_t n = b_shape.dimensions(b_shape.rank() - 1);
1463   int64_t batch_size = std::accumulate(
1464       b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1},
1465       [](int64_t a, int64_t b) { return a * b; });
1466   int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_ty);
1467   int64_t a_batch_stride =
1468       backend_config.left_side() ? m * m * elem_size : n * n * elem_size;
1469   int64_t b_batch_stride = m * n * elem_size;
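  // Note (added, not in the original): when solving from the left, A is an
  // m x m triangular matrix, otherwise it is n x n, which is why
  // a_batch_stride depends on backend_config.left_side(); each batch of the
  // m x n right-hand side B advances by b_batch_stride bytes.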
1470   thunks.push_back(std::make_unique<TriangularSolveThunk>(
1471       GetThunkInfo(op), backend_config,
1472       PtxOptsFromDebugOptions(hlo_module_config_.debug_options()),
1473       /*a_buffer=*/a_slice, /*b_buffer=*/result_slice, temp_slice, elem_ty,
1474       batch_size, m, n, a_batch_stride, b_batch_stride));
1475 
1476   // Elide the sequential thunk if there's no copy.
1477   if (thunks.size() == 1) {
1478     AddThunkToThunkSequence(std::move(thunks[0]));
1479   } else {
1480     AddThunkToThunkSequence(
1481         std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
1482   }
1483   return OkStatus();
1484 }
1485 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1486 
1487 // Convert the following form of fusion region:
1488 //   fusion() {
1489 //     %0 = tensor_load %external_memref0
1490 //     %1 = tensor_load %external_memref1
1491 //     ...
1492 //     tensor_store %ret, %external_memref2
1493 //   }
1494 // to
1495 //   fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1496 //     ...
1497 //     mhlo.return %ret
1498 //   })
1499 //
1500 // So that it's suitable for MHLO -> XLA HLO conversion.
1501 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1502 // instead.
1503 static Status ProcessFusionForConversion(mlir::Region* region,
1504                                          std::vector<Shape>* operand_shapes,
1505                                          std::vector<Shape>* output_shapes) {
1506   std::vector<mlir::bufferization::ToTensorOp> loads;
1507   std::vector<mlir::memref::TensorStoreOp> stores;
1508 
1509   region->walk([&](mlir::bufferization::ToTensorOp load) {
1510     if (load.getMemref().getParentRegion() != region) {
1511       loads.push_back(load);
1512     }
1513   });
1514 
1515   region->walk([&](mlir::memref::TensorStoreOp store) {
1516     if (store.getMemref().getParentRegion() != region) {
1517       stores.push_back(store);
1518     }
1519   });
1520 
1521   for (auto& load : loads) {
1522     auto arg = region->addArgument(load.getType(), region->getLoc());
1523     load.replaceAllUsesWith(arg);
1524     Shape shape = GetShape(load.getResult());
1525     operand_shapes->push_back(std::move(shape));
1526     load.erase();
1527   }
1528 
1529   std::vector<mlir::Value> returned_values;
1530   for (auto store : stores) {
1531     Shape shape = GetShape(store.getMemref());
1532     output_shapes->push_back(shape);
1533 
1534     returned_values.push_back(store.getTensor());
1535     store.erase();
1536   }
1537 
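  // The last operation of the region's trailing block is the old terminator;
  // erase it so we can append an mhlo.return of the stored values, matching
  // the target form shown in the comment above this function.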
1538   region->back().back().erase();
1539   auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1540   auto loc = returned_values[0].getLoc();
1541   b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1542   return OkStatus();
1543 }
1544 
1545 // We can iterate the output buffer in logical order instead of physical order
1546 // when it is safe and profitable to do so.
1547 static bool EnableLogicalIndexGenerationForOutput(
1548     LaunchDimensionsConfig launch_config, LaunchDimensions launch_dimensions,
1549     absl::Span<llvm_ir::IrArray> operand_arrays,
1550     absl::Span<llvm_ir::IrArray> output_arrays) {
1551   // Safety checks.  Currently the logical index generation code has
1552   // limitations. Violating these conditions can give wrong output.
1553   if (launch_config.row_vectorized || launch_config.few_waves) return false;
1554   if (output_arrays.size() != 1) return false;
1555   const Shape& output_shape = output_arrays[0].GetShape();
1556   if (output_shape.is_dynamic()) return false;
1557   // Currently we require that the number of threads * unroll factor
1558   // exactly equals the number of output elements. It simplifies bounds checking
1559   // within the LLVM code generated for the fused kernel.
1560   if (ShapeUtil::ElementsIn(output_shape) !=
1561       (launch_dimensions.launch_bound() * launch_config.unroll_factor)) {
1562     return false;
1563   }
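  // Worked example (added, not in the original): if the kernel launches
  // 262,144 threads in total with unroll_factor = 4, the product is
  // 1,048,576, so a static f32[1024,1024] output (1,048,576 elements)
  // satisfies this requirement.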
1564   if (launch_dimensions.thread_counts_per_block().y > 1 ||
1565       launch_dimensions.thread_counts_per_block().z > 1) {
1566     return false;
1567   }
1568   // Safety checks finish.
1569 
1570   // Profitability checks.
1571   // TODO(b/228209668) Investigate alternative profitability heuristics.
1572   // We should have a single output, and it should not be in row-major layout.
1573   if (LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) {
1574     return false;
1575   }
1576   // We should have multiple inputs and all of them should have row-major
1577   // layout.
1578   if (operand_arrays.size() <= 1) return false;
1579   for (const llvm_ir::IrArray& input : operand_arrays) {
1580     const Shape& input_shape = input.GetShape();
1581     if (input_shape.is_dynamic()) return false;
1582     if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout())) {
1583       return false;
1584     }
1585   }
1586   return true;
1587 }
1588 
1589 Status IrEmitterUnnested::EmitLaunchFunc(mlir::Operation* op) {
1590   auto launch_func = mlir::cast<mlir::gpu::LaunchFuncOp>(op);
1591   auto kernel_func =
1592       mlir::SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1593           launch_func, launch_func.kernel());
1594   if (!kernel_func) {
1595     return InternalError("kernel '%s' not found",
1596                          launch_func.getKernelName().str());
1597   }
1598 
1599   // Lower kernel module to NVVM.
1600   auto gpu_module = kernel_func->getParentOfType<mlir::gpu::GPUModuleOp>();
1601   std::unique_ptr<llvm::Module> llvm_module = mlir::translateModuleToLLVMIR(
1602       gpu_module, module_->getContext(), gpu_module.getName());
1603   if (!llvm_module)
1604     return InternalError("Failed to translate GPU module to LLVM");
1605 
1606   // Add kernel to LLVM module.
1607   llvm_module->setDataLayout(module_->getDataLayout());
1608   llvm::Linker::linkModules(*module_, std::move(llvm_module));
1609 
1610   // Retrieve launch dimensions from arith.constant ops.
1611   auto get_dim3d = [](mlir::gpu::KernelDim3 dim3) {
1612     auto get_const = [](mlir::Value value) -> int64_t {
1613       auto const_op = value.getDefiningOp<mlir::arith::ConstantOp>();
1614       if (!const_op) return -1;
1615       auto attr = const_op.getValue().cast<mlir::IntegerAttr>();
1616       if (!attr) return -1;
1617       return attr.getValue().getSExtValue();
1618     };
1619     return LaunchDimensions::Dim3D{get_const(dim3.x), get_const(dim3.y),
1620                                    get_const(dim3.z)};
1621   };
1622   LaunchDimensions launch_dimensions(
1623       get_dim3d(launch_func.getGridSizeOperandValues()),
1624       get_dim3d(launch_func.getBlockSizeOperandValues()));
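  // Rough example (added, not in the original): for a launch like
  //   gpu.launch_func ... blocks in (%c32, %c1, %c1) threads in (%c128, %c1, %c1)
  // where each %cN is defined by an arith.constant op, this yields a grid of
  // {32, 1, 1} blocks and {128, 1, 1} threads per block; any non-constant
  // operand maps to -1.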
1625 
1626   // Create BufferSlice array from launch_func arguments, using the
1627   // attribute indicating which arguments are written by the kernel.
1628   std::vector<BufferSlice> slices;
1629   unsigned num_kernel_operands = launch_func.getNumKernelOperands();
1630   slices.reserve(num_kernel_operands);
1631   mlir::ArrayRef<mlir::Attribute> written_operands =
1632       mlir::getWrittenOperandsAttribute(launch_func).getValue();
1633   for (const auto& [operand, written] :
1634        llvm::zip_first(launch_func.operands(),
1635                        written_operands.take_back(num_kernel_operands))) {
1636     BufferSlice slice;
1637     TF_ASSIGN_OR_RETURN(slice.buffer_slice,
1638                         GetAllocationSlice(operand, &slice.constant_name));
1639     slice.shape = GetShape(operand);
1640     slice.written = written.cast<mlir::BoolAttr>().getValue();
1641     slices.push_back(std::move(slice));
1642   }
1643 
1644   // Add kernel prototype to module_, kernel thunk to thunk_sequence_.
1645   std::string kernel_name = GetIrNameFromLoc(launch_func.getLoc());
1646   std::vector<llvm_ir::IrArray> ir_arrays;
1647   TF_ASSIGN_OR_RETURN(
1648       std::unique_ptr<Thunk> kernel_thunk,
1649       BuildKernelThunkImpl(kernel_name, GetThunkInfo(op), slices, &ir_arrays,
1650                            launch_dimensions));
1651   thunk_sequence_.emplace_back(std::move(kernel_thunk));
1652 
1653   // Move function body into kernel prototype.
1654   llvm::Function* prototype_func = b_.GetInsertBlock()->getParent();
1655   llvm::Function* implementation_func =
1656       module_->getFunction(kernel_func.getName());
1657   prototype_func->getBasicBlockList().splice(
1658       prototype_func->end(), implementation_func->getBasicBlockList());
1659   for (const auto& [arg, ir_array] :
1660        llvm::zip_first(implementation_func->args(), ir_arrays)) {
1661     arg.replaceAllUsesWith(ir_array.GetBasePointer());
1662   }
1663   implementation_func->eraseFromParent();
1664 
1665   // Replace pre-existing return with unconditional branch to next block.
1666   llvm::Instruction* terminator =
1667       prototype_func->getEntryBlock().getTerminator();
1668   llvm::BranchInst::Create(&*std::next(prototype_func->begin()), terminator);
1669   terminator->eraseFromParent();
1670 
1671   return Status::OK();
1672 }
1673 
1674 // TODO(timshen): update the comment once the HandleFusion code path is deleted.
1675 //
1676 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1677 // subclass. The logic is de-virtualized and less scattered.
1678 Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) {
1679   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
1680 
1681   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1682                       GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
1683                                                           /*is_fusion=*/true));
1684 
1685   int unroll_factor;
1686   if (!MayPreventVectorization(fusion)) {
1687     unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1688   } else {
1689     unroll_factor = 1;
1690   }
1691 
1692   bool row_vectorized;
1693   int num_big_inputs;
1694   std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(fusion);
1695   bool few_waves = [fusion, row_vectorized, num_big_inputs]() mutable {
1696     for (mlir::Operation& op : fusion.getRegion().front()) {
1697       if (mlir::isa<mlir::bufferization::ToTensorOp,
1698                     mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1699                     mlir::mhlo::ReturnOp, mlir::mhlo::ConstantOp>(op)) {
1700         continue;
1701       }
1702       HloOpcode opcode = *MhloToHloOpcode(&op);
1703       if (HloInstruction::IsOpElementwise(opcode)) {
1704         continue;
1705       }
1706       if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1707         if (broadcast.broadcast_dimensions().empty() ||
1708             // More than 2 big inputs cause a speed regression.
1709             (row_vectorized && num_big_inputs <= 3)) {
1710           continue;
1711         }
1712       }
1713       VLOG(2) << "few_waves not enabled due to: " << MlirToString(&op);
1714       return false;
1715     }
1716     return true;
1717   }();
1718 
1719   Shape element_shape = GetShape(fusion.getOutputBuffers()[0]);
1720   LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
1721                                        row_vectorized};
1722   // Check that the shape is supported.
1723   if (launch_config.row_vectorized &&
1724       ThreadsPerBlockRowVectorized(element_shape,
1725                                    ir_emitter_context_->gpu_device_info(),
1726                                    launch_config) <= 0) {
1727     VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
1728     launch_config.row_vectorized = false;
1729     launch_config.few_waves = false;
1730   }
1731 
1732   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
1733                       CalculateLaunchDimensions(
1734                           element_shape, ir_emitter_context_->gpu_device_info(),
1735                           launch_config));
1736 
1737   std::vector<llvm_ir::IrArray> ir_arrays;
1738   Thunk* kernel_thunk;
1739   {
1740     TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk_ptr,
1741                         BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
1742                                          launch_dimensions));
1743     kernel_thunk = kernel_thunk_ptr.get();
1744     thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1745   }
1746 
1747   absl::Span<llvm_ir::IrArray> operand_arrays =
1748       absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size());
1749   absl::Span<llvm_ir::IrArray> output_element_arrays =
1750       absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size(),
1751                                         fusion.getOutputBuffers().size());
1752 
1753   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1754                                           GetNestedComputer());
1755   FusedIrEmitter fused_emitter(elemental_emitter);
1756 
1757   for (int i = 0; i < fusion.getInputBuffers().size(); i++) {
1758     auto* builder = &b_;
1759     auto ir_array = operand_arrays[i];
1760     fused_emitter.BindGenerator(
1761         *fused_computation->parameter_instruction(i),
1762         [builder, ir_array](llvm_ir::IrArray::Index index) {
1763           return ir_array.EmitReadArrayElement(index, builder);
1764         });
1765   }
1766   launch_config.logical_order = EnableLogicalIndexGenerationForOutput(
1767       launch_config, launch_dimensions, operand_arrays, output_element_arrays);
1768   TF_ASSIGN_OR_RETURN(
1769       auto element_generator,
1770       fused_emitter.GetGenerator(*fused_computation->root_instruction()));
1771 
1772   llvm::Type* index_type =
1773       GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_);
1774 
1775   TF_RETURN_IF_ERROR(
1776       ParallelLoopEmitter(element_generator, output_element_arrays,
1777                           launch_dimensions, &b_, launch_config)
1778           .EmitLoop(GetIrNameFromLoc(fusion->getLoc()), index_type));
1779 
1780   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1781   return OkStatus();
1782 }
1783 
1784 // Returns whether any of the roots of the fusion are unnested reductions.
1785 static bool HasAnyUnnestedReductionRoot(mlir::lmhlo::FusionOp fusion) {
1786   return absl::c_any_of(fusion.getFusionRoots(), [&](mlir::Operation* op) {
1787     return IsReductionFromOrToContiguousDimensions(op);
1788   });
1789 }
1790 
1791 Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) {
1792   auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op);
1793   TF_ASSIGN_OR_RETURN(
1794       const HloComputation* fused_computation,
1795       GetOrCreateSubComputationFromRegion(&fusion_op.getRegion(),
1796                                           /*is_fusion=*/true));
1797 
1798   if (HasAnyUnnestedReductionRoot(fusion_op)) {
1799     return EmitUnnestedReduction(fusion_op);
1800   }
1801 
1802   auto fusion_results = fusion_op.getFusionResults();
1803   TF_RET_CHECK(!fusion_results.empty());
1804   if (fusion_results.size() > 1) {
1805     // In the case of a root tuple, it can be either a reduce or a slice input
1806     // fusion.
1807     if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) {
1808       // The emitter doesn't support all cases. If it's not supported, fallback
1809       // to ElementalIrEmitter.
1810       auto status = EmitInputFusibleNonStridedSlices(op);
1811       if (status.code() == tensorflow::error::FAILED_PRECONDITION) {
1812         return EmitLoopFusion(op);
1813       }
1814       return status;
1815     }
1816   }
1817 
1818   mlir::Operation* fusion_root = fusion_results[0].getDefiningOp();
1819   if (mlir::isa<mlir::mhlo::ScatterOp>(fusion_root)) {
1820     return EmitScatter(fusion_op, fused_computation);
1821   }
1822 
1823   if (!IsSingleInstructionFusion(fusion_op) &&
1824       CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
1825           fusion_op, ir_emitter_context_->allocations())) {
1826     return EmitDynamicUpdateSlice(fusion_op, fused_computation);
1827   }
1828 
1829   if (auto copy = mlir::dyn_cast<mlir::mhlo::CopyOp>(fusion_root);
1830       copy && IsSingleInstructionFusion(fusion_op)) {
1831     auto operands = GetHloOperands(fusion_op);
1832     auto outputs = GetHloOutputs(fusion_op);
1833     TF_RET_CHECK(operands.size() == 1);
1834     TF_RET_CHECK(outputs.size() == 1);
1835     auto operand_shape = GetShape(operands[0]);
1836     auto output_shape = GetShape(outputs[0]);
1837 
1838     CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
1839     auto maybe_slice = GetAllocationSlice(operands[0]);
1840     if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
1841         maybe_slice.ok()) {
1842       // Copy the operand into the output if it's not the same buffer already.
1843       auto operand_buffer = *maybe_slice;
1844       auto destination_buffer = *GetAllocationSlice(outputs[0]);
1845       if (operand_buffer != destination_buffer) {
1846         AddThunkToThunkSequence(std::make_unique<DeviceToDeviceCopyThunk>(
1847             GetThunkInfo(op),
1848             /*source_address=*/operand_buffer,
1849             /*destination_buffer=*/destination_buffer,
1850             /*mem_size=*/
1851             ByteSizeOf(operand_shape)));
1852       }
1853       return OkStatus();
1854     }
1855   }
1856 
1857   if (std::optional<TransposeDimsAndParams> descr =
1858           Match021Transpose(fused_computation)) {
1859     return Emit021Transpose(*descr, fusion_op);
1860   }
1861 
1862   return EmitLoopFusion(op);
1863 }
1864 
1865 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
1866     const ReductionOutputMap& result_ir_arrays, const IrArray::Index& index,
1867     const ReductionCodegenInfo& reduction_info,
1868     const ExtraOutputGensMap& extra_output_gens) {
1869   // Compute all extra output values before writing them. This avoids
1870   // overwriting aliased input/output buffers before all reads have occurred.
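  // (Added note, not in the original:) e.g. if an extra output shares a
  // buffer with one of the inputs, writing it eagerly could clobber a value
  // that a later generator still needs to read.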
1871   absl::flat_hash_map<const HloInstruction*, llvm::Value*>
1872       extra_output_ir_values;
1873   for (const auto& p : extra_output_gens) {
1874     TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
1875                         p.second(index));
1876     extra_output_ir_values[p.first] = extra_output_ir_value;
1877   }
1878   for (const auto& p : extra_output_ir_values) {
1879     absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays.at(p.first);
1880     CHECK_EQ(result_ir.size(), 1);
1881     result_ir[0].EmitWriteArrayElement(
1882         index, p.second, &b_, /*use_linear_index=*/
1883         reduction_info.GetNumPartialResults() == 1);
1884   }
1885   return OkStatus();
1886 }
1887 
1888 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(
1889     const std::string& op_name) {
1890   if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
1891     return Unimplemented(
1892         "HLO instruction %s does not have a deterministic implementation, "
1893         "but run-to-run determinism is required by "
1894         "--xla_gpu_deterministic_ops.",
1895         op_name);
1896   }
1897   return OkStatus();
1898 }
1899 
1900 Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
1901   auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
1902 
1903   const Shape source_shape = GetShape(select_and_scatter_op.getSource());
1904   const Shape operand_shape = GetShape(select_and_scatter_op.getOperand());
1905   const int64_t rank = operand_shape.rank();
1906 
1907   CHECK_EQ(rank, source_shape.rank());
1908   if (select_and_scatter_op.getWindowDimensions()) {
1909     CHECK_EQ(rank, select_and_scatter_op.getWindowDimensions()->size());
1910   }
1911 
1912   TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(
1913       mlir::GetNameFromLoc(select_and_scatter_op.getLoc())));
1914 
1915   std::string name = GetIrNameFromLoc(select_and_scatter_op.getLoc());
1916 
1917   // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
1918   // consisting of two thunks, an initializer KernelThunk that initializes
1919   // the output and another KernelThunk that accumulates the scattered
1920   // elements.
1921   ThunkSequence thunks;
1922   thunks.emplace_back();
1923   TF_ASSIGN_OR_RETURN(
1924       thunks.back(),
1925       BuildInitializerThunk(op, select_and_scatter_op.getInitValue(),
1926                             select_and_scatter_op.getOut()));
1927 
1928   TF_ASSIGN_OR_RETURN(
1929       LaunchDimensions launch_dimensions,
1930       CalculateLaunchDimensions(source_shape,
1931                                 ir_emitter_context_->gpu_device_info()));
1932   std::vector<llvm_ir::IrArray> ir_arrays;
1933   thunks.emplace_back();
1934   // Init value is not needed in IR emission.
1935   TF_ASSIGN_OR_RETURN(
1936       thunks.back(),
1937       BuildKernelThunk(
1938           select_and_scatter_op,
1939           {select_and_scatter_op.getOperand(),
1940            select_and_scatter_op.getSource(), select_and_scatter_op.getOut()},
1941           Thunk::ThunkInfo(), &ir_arrays, launch_dimensions));
1942 
1943   CHECK_EQ(ir_arrays.size(), 3);
1944   const IrArray& operand_array = ir_arrays[0];
1945   const IrArray& source_array = ir_arrays[1];
1946   const IrArray& out_array = ir_arrays[2];
1947 
1948   auto select_and_scatter_thunk =
1949       std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks));
1950 
1951   llvm::Type* index_type = GetIndexTypeForKernel(
1952       select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
1953   auto index_typed_constant = [&](uint64_t c) -> llvm::Constant* {
1954     return llvm::ConstantInt::get(index_type, c);
1955   };
1956 
1957   // kSelectAndScatter is implemented as two kernel launches: the first launch
1958   // initializes the output array to the given initial value,
1959   // and the second accumulates the "source" matrix to the
1960   // selected elements in the output array. The first launch is already
1961   // implemented by the initializer thunk generated earlier, so this function
1962   // only needs to take care of the select-and-scatter part.
1963   //
1964   // Pseudo code for select-and-scatter:
1965   //
1966   // for (coordinates S in the source):  # This loop is parallel.
1967   //   initialized_flag = false
1968   //   for (coordinates W in the window):
1969   //     I = S * stride + W - pad_low
1970   //     if I within bounds of operand:
1971   //       if !(initialized_flag and select(selected_value, operand(I))):
1972   //         selected_value = operand(I)
1973   //         selected_index = I
1974   //         initialized_flag = true
1975   //   if initialized_flag:
1976   //     output(selected_index) = scatter(output(selected_index), source(S))
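  //
  // Worked example for the index computation above (added, not in the
  // original): with stride = 2, pad_low = 1 and source coordinate S = 3, the
  // window positions W = 0, 1, 2 map to operand indices I = 3*2 + W - 1,
  // i.e. 5, 6 and 7.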
1977   auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
1978     // Allocate space to keep the currently selected value, its index, and a
1979     // boolean flag if the value is initialized. The initialized_flag is set
1980     // false.
1981     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
1982         llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(), module_),
1983         "selected_value_address", &b_);
1984 
1985     llvm::AllocaInst* selected_index_address =
1986         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
1987             index_type, index_typed_constant(rank), "selected_index_address",
1988             &b_);
1989 
1990     llvm::AllocaInst* initialized_flag_address =
1991         llvm_ir::EmitAllocaAtFunctionEntry(b_.getInt1Ty(),
1992                                            "initialized_flag_address", &b_);
1993     Store(b_.getInt1(false), initialized_flag_address);
1994 
1995     // Create the inner loop to iterate over the window.
1996     llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
1997                                       index_type);
1998 
1999     DimensionVector window_size;
2000     mlir::DenseIntElementsAttr window_dimensions =
2001         select_and_scatter_op.getWindowDimensions().getValue();
2002     for (const auto& dim : window_dimensions) {
2003       window_size.push_back(dim.getSExtValue());
2004       CHECK_GT(dim.getSExtValue(), 0);
2005     }
2006 
2007     const IrArray::Index window_index = window_loops.AddLoopsForShape(
2008         ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2009         "window");
2010     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2011                                    &b_);
2012 
2013     // Compute the operand index to visit and check whether the
2014     // operand index is within the bounds. The unsigned comparison includes
2015     // checking whether the operand index >= 0.
2016     std::vector<llvm::Value*> operand_multi_index(source_index.size());
2017     llvm::Value* in_bounds_condition = b_.getInt1(true);
2018 
2019     auto strides = *select_and_scatter_op.getWindowStrides();
2020     auto paddings = *select_and_scatter_op.getPadding();
2021 
2022     for (auto stride_and_padding :
2023          llvm::enumerate(llvm::zip(strides, paddings))) {
2024       const int i = stride_and_padding.index();
2025       int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2026       int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2027 
2028       llvm::Value* strided_index =
2029           NSWMul(source_index[i], index_typed_constant(stride));
2030       operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2031                                       index_typed_constant(padding));
2032       llvm::Value* index_condition = ICmpULT(
2033           operand_multi_index[i],
2034           index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2035       in_bounds_condition = And(in_bounds_condition, index_condition);
2036     }
2037 
2038     // Only need to do something if the operand index is within the bounds.
2039     // First check if the initialized_flag is set.
2040     llvm_ir::LlvmIfData if_in_bounds =
2041         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2042     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2043     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2044         Load(initialized_flag_address->getAllocatedType(),
2045              initialized_flag_address),
2046         "initialized", &b_);
2047 
2048     // If the initialized_flag is false, initialize the selected value and index
2049     // with the currently visiting operand.
2050     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
2051     const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2052       for (int64_t i = 0; i < rank; ++i) {
2053         llvm::Value* selected_index_address_slot =
2054             InBoundsGEP(selected_index_address->getAllocatedType(),
2055                         selected_index_address, {b_.getInt32(i)});
2056         Store(operand_index[i], selected_index_address_slot);
2057       }
2058     };
2059     IrArray::Index operand_index(operand_multi_index, operand_shape,
2060                                  index_type);
2061     llvm::Value* operand_data =
2062         operand_array.EmitReadArrayElement(operand_index, &b_);
2063     Store(operand_data, selected_value_address);
2064     save_operand_index(operand_index);
2065     Store(b_.getInt1(true), initialized_flag_address);
2066 
2067     // If the initialized_flag is true, call the `select` function to
2068     // potentially update the selected value and index with the currently
2069     // visiting operand.
2070     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2071     llvm::Value* operand_address =
2072         operand_array.EmitArrayElementAddress(operand_index, &b_);
2073     llvm::AllocaInst* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2074         llvm_ir::PrimitiveTypeToIrType(PRED, module_), "select_return_buffer",
2075         &b_);
2076 
2077     TF_ASSIGN_OR_RETURN(
2078         const HloComputation* select_computation,
2079         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.getSelect(),
2080                                             /*is_fusion=*/false));
2081 
2082     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2083         *select_computation, {selected_value_address, operand_address},
2084         select_return_buffer));
2085     llvm::Value* result =
2086         Load(select_return_buffer->getAllocatedType(), select_return_buffer);
2087 
2088     // If the 'select' function returns false, update the selected value and the
2089     // index to the currently visiting operand.
2090     llvm::Value* cond =
2091         ICmpNE(result,
2092                llvm::ConstantInt::get(
2093                    llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2094                "boolean_predicate");
2095     llvm_ir::LlvmIfData if_select_lhs =
2096         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2097     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2098     Store(Load(operand_array.GetElementLlvmType(), operand_address),
2099           selected_value_address);
2100     save_operand_index(operand_index);
2101 
2102     // If the initialized_flag is true, write to the selected index of the
2103     // output; otherwise the window is outside the source (in the padding) and
2104     // should be ignored.
2105     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2106                                    &b_);
2107     llvm_ir::LlvmIfData if_should_store = llvm_ir::EmitIfThenElse(
2108         Load(initialized_flag_address->getAllocatedType(),
2109              initialized_flag_address),
2110         "should-store", &b_, /*emit_else=*/false);
2111     llvm_ir::SetToFirstInsertPoint(if_should_store.true_block, &b_);
2112 
2113     // After iterating over the window elements, scatter the source element to
2114     // the selected index of the output. The value we store at the output
2115     // location is computed by calling the `scatter` function with the source
2116     // value and the current output value.
2117     std::vector<llvm::Value*> selected_multi_index;
2118     for (int64_t i = 0; i < rank; ++i) {
2119       llvm::Value* selected_index_address_slot =
2120           InBoundsGEP(selected_index_address->getAllocatedType(),
2121                       selected_index_address, {b_.getInt32(i)});
2122       selected_multi_index.push_back(
2123           Load(selected_index_address->getAllocatedType(),
2124                selected_index_address_slot));
2125     }
2126     const Shape output_shape = GetShape(select_and_scatter_op.getOut());
2127     llvm::Value* source_value_address =
2128         source_array.EmitArrayElementAddress(source_index, &b_);
2129     IrArray::Index selected_index(selected_multi_index, output_shape,
2130                                   operand_index.GetType());
2131     llvm::Value* output_value_address =
2132         out_array.EmitArrayElementAddress(selected_index, &b_);
2133 
2134     TF_ASSIGN_OR_RETURN(
2135         const HloComputation* scatter_computation,
2136         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.getScatter(),
2137                                             /*is_fusion=*/false));
2138 
2139     return EmitAtomicOperationForNestedComputation(
2140         *scatter_computation, output_value_address, source_value_address,
2141         source_array.GetElementLlvmType());
2142   };
2143 
2144   AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
2145   return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2146                              &b_)
2147       .EmitLoop(name, index_type);
2148 }
2149 
2150 Status IrEmitterUnnested::EmitWhile(mlir::Operation* op) {
2151   auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(op);
2152 
2153   auto cond_result = GetHloOutputs(while_op);
2154   TF_RET_CHECK(cond_result.size() == 1);
2155   TF_RET_CHECK(cond_result[0]
2156                    .getType()
2157                    .cast<mlir::ShapedType>()
2158                    .getElementType()
2159                    .isInteger(/*width=*/1))
2160       << "While condition computation must return bool";
2161 
2162   //  Build ForThunk for conformant while loops, otherwise build WhileThunk.
2163   if (while_op.getTripCount()) {
2164     TF_ASSIGN_OR_RETURN(auto thunk, BuildForThunk(while_op, GetThunkInfo(op),
2165                                                   *while_op.getTripCount()));
2166     AddThunkToThunkSequence(std::move(thunk));
2167   } else {
2168     TF_ASSIGN_OR_RETURN(auto thunk,
2169                         BuildWhileThunk(while_op, GetThunkInfo(op)));
2170     AddThunkToThunkSequence(std::move(thunk));
2171   }
2172   return OkStatus();
2173 }
2174 
2175 Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) {
2176   auto rng_op = mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(op);
2177 
2178   // Emit a kernel to increment the global state for Philox RNG algorithm.
2179   std::vector<llvm_ir::IrArray> ir_arrays;
2180   TF_ASSIGN_OR_RETURN(
2181       auto kernel_thunk,
2182       BuildKernelThunk(rng_op, rng_op.getState(), GetThunkInfo(op), &ir_arrays,
2183                        LaunchDimensions()));
2184   AddThunkToThunkSequence(std::move(kernel_thunk));
2185 
2186   llvm::Value* old_state =
2187       llvm_ir::RngGetAndUpdateState(rng_op.getDelta(), module_, &b_);
2188 
2189   const Shape shape = GetShape(rng_op.getState());
2190 
2191   llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2192       llvm_ir::IrArray::Index(
2193           /*linear=*/b_.getInt64(0), shape, &b_),
2194       &b_, "rng_state_address");
2195   output_address = BitCast(
2196       output_address, llvm::PointerType::get(
2197                           old_state->getType(),
2198                           output_address->getType()->getPointerAddressSpace()));
2199   Store(old_state, output_address);
2200 
2201   return OkStatus();
2202 }
2203 
2204 Status IrEmitterUnnested::EmitScatter(mlir::Operation* op) {
2205   ThunkSequence thunks;
2206 
2207   auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(op);
2208 
2209   if (!scatter_op.getUniqueIndices()) {
2210     TF_RETURN_IF_ERROR(
2211         AssertNonDeterminismIsOkay(GetIrNameFromLoc(scatter_op.getLoc())));
2212   }
2213 
2214   TF_ASSIGN_OR_RETURN(auto operand_buffer,
2215                       GetAllocationSlice(scatter_op.getOperand()));
2216   TF_ASSIGN_OR_RETURN(auto output_buffer,
2217                       GetAllocationSlice(scatter_op.getOutput()));
2218 
2219   // Copy the operand into the output if it's not the same buffer already.
2220   if (operand_buffer != output_buffer) {
2221     thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2222         Thunk::ThunkInfo(),
2223         /*source_address=*/operand_buffer,
2224         /*destination_buffer=*/output_buffer,
2225         /*mem_size=*/
2226         ShapeUtil::ByteSizeOf(GetShape(scatter_op.getOutput()))));
2227   }
2228 
2229   const Shape& data_shape = GetShape(scatter_op.getUpdates());
2230   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
2231                       CalculateLaunchDimensions(
2232                           data_shape, ir_emitter_context_->gpu_device_info()));
2233 
2234   // Create kernel thunk for all operands except the first one (`operand`). The
2235   // code generated for scatter below assumes that the input operand is already
2236   // copied into the output, so does not use it in codegen.
2237   std::vector<llvm_ir::IrArray> ir_arrays;
2238   thunks.emplace_back();
2239   TF_ASSIGN_OR_RETURN(
2240       thunks.back(),
2241       BuildKernelThunk(scatter_op, scatter_op.getOperands().drop_front(),
2242                        GetThunkInfo(op), &ir_arrays, launch_dimensions));
2243 
2244   CHECK_EQ(ir_arrays.size(), 3);
2245   const IrArray& scatter_indices = ir_arrays[0];
2246   const IrArray& updates = ir_arrays[1];
2247   const IrArray& output = ir_arrays[2];
2248 
2249   auto get_index_type = [&](int64_t launch_size) {
2250     return GetIndexTypeForKernel(scatter_op, launch_size, &b_);
2251   };
2252 
2253   TF_RETURN_IF_ERROR(EmitScatter(
2254       thunks.back().get(), scatter_op, launch_dimensions, output,
2255       /*scatter_indices_gen=*/
2256       [&](const IrArray::Index& index) {
2257         return scatter_indices.EmitReadArrayElement(index, &b_,
2258                                                     "scatter_index");
2259       },
2260       /*updates_gen=*/
2261       [&](const IrArray::Index& index) {
2262         return updates.EmitReadArrayElement(index, &b_, "update");
2263       },
2264       /* get_index_type=*/
2265       get_index_type));
2266 
2267   // Elide the sequential thunk if there's no copy.
2268   if (thunks.size() == 1) {
2269     AddThunkToThunkSequence(std::move(thunks[0]));
2270   } else {
2271     AddThunkToThunkSequence(
2272         std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2273   }
2274 
2275   return OkStatus();
2276 }
2277 
2278 Status IrEmitterUnnested::EmitScatter(
2279     Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2280     const LaunchDimensions& launch_dimensions, const llvm_ir::IrArray& output,
2281     const llvm_ir::ElementGenerator& scatter_indices_gen,
2282     const llvm_ir::ElementGenerator& updates_gen,
2283     std::function<llvm::Type*(int64_t)> get_index_type) {
2284   const Shape operand_shape = GetShape(scatter.getOperand());
2285   CHECK(ShapeUtil::Equal(GetShape(scatter.getOutput()), operand_shape));
2286 
2287   TF_ASSIGN_OR_RETURN(
2288       const HloComputation* update_computation,
2289       GetOrCreateSubComputationFromRegion(&scatter.getUpdateComputation(),
2290                                           /*is_fusion=*/false));
2291 
2292   ScatterDescriptor desc;
2293   desc.name = GetIrNameFromLoc(scatter.getLoc());
2294   desc.operand_shape = operand_shape;
2295   desc.scatter_indices_shape = GetShape(scatter.getScatterIndices());
2296   desc.updates_shape = GetShape(scatter.getUpdates());
2297   desc.dim_numbers = scatter.getScatterDimensionNumbers();
2298   desc.unique_indices = scatter.getUniqueIndices();
2299   desc.update_computation = update_computation;
2300   desc.output = output;
2301   desc.scatter_indices_gen = scatter_indices_gen;
2302   desc.updates_gen = updates_gen;
2303   desc.get_index_type = get_index_type;
2304   return EmitScatter(desc, thunk, launch_dimensions);
2305 }
2306 
2307 Status IrEmitterUnnested::EmitScatter(
2308     const ScatterDescriptor& desc, Thunk* thunk,
2309     const LaunchDimensions& launch_dimensions) {
2310   if (!desc.unique_indices) {
2311     TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2312   }
2313   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2314     std::vector<llvm::Value*> raw_window_multidim;
2315     std::vector<llvm::Value*> input_scatter_multidim;
2316     std::vector<int64_t> raw_window_bounds;
2317 
2318     // Partition the index into window indices and scatter indices.
2319     for (int64_t i = 0, e = index.size(); i != e; ++i) {
2320       // For window indices, also remember the window size; this comes in handy
2321       // later.
2322       if (llvm::is_contained(desc.dim_numbers.getUpdateWindowDims(), i)) {
2323         raw_window_multidim.push_back(index[i]);
2324         raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2325       } else {
2326         input_scatter_multidim.push_back(index[i]);
2327       }
2328     }
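    // Example (added, not in the original): if the updates tensor has shape
    // [N, W] and update_window_dims = {1}, then dimension 1 of the update
    // index is a window index (with bound W) and dimension 0 is a scatter
    // index selecting which slice to update.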
2329     DCHECK_EQ(raw_window_multidim.size(),
2330               desc.dim_numbers.getUpdateWindowDims().size());
2331 
2332     // Apply inserted_window_dims to the window dimensions.
2333     int64_t raw_window_multidim_idx = 0;
2334     llvm::SmallVector<llvm::Value*> input_window_multidim;
2335     llvm::SmallVector<int64_t> input_window_bounds;
2336     const int64_t rank = desc.operand_shape.rank();
2337     input_window_bounds.reserve(rank);
2338     input_window_multidim.reserve(rank);
2339 
2340     for (int64_t i = 0; i != rank; ++i) {
2341       if (llvm::is_contained(desc.dim_numbers.getInsertedWindowDims(), i)) {
2342         input_window_bounds.push_back(1);  // Trivial dimension.
2343         input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2344       } else {
2345         input_window_bounds.push_back(
2346             raw_window_bounds[raw_window_multidim_idx]);
2347         input_window_multidim.push_back(
2348             raw_window_multidim[raw_window_multidim_idx]);
2349         ++raw_window_multidim_idx;
2350       }
2351     }
2352     DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2353 
2354     // Insert a 1 dimension at the end if index_vector_dim requests one.
2355     Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2356     if (desc.dim_numbers.getIndexVectorDim() ==
2357         desc.scatter_indices_shape.rank()) {
2358       scatter_indices_shape_fixed.add_dimensions(1);
2359       scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2360           desc.dim_numbers.getIndexVectorDim());
2361     }
2362 
2363     // Now load the indices corresponding to the current window from
2364     // scatter_indices.
2365     std::vector<llvm::Value*> raw_scatter_index_multidim =
2366         input_scatter_multidim;
2367     raw_scatter_index_multidim.insert(raw_scatter_index_multidim.begin() +
2368                                           desc.dim_numbers.getIndexVectorDim(),
2369                                       nullptr);
2370     llvm::Value* is_in_bounds = b_.getTrue();
2371     for (int64_t i = 0,
2372                  e = desc.dim_numbers.getScatterDimsToOperandDims().size();
2373          i != e; ++i) {
2374       // Our index is stored along index_vector_dim; insert that into the
2375       // lookup index into scatter_indices.
2376       raw_scatter_index_multidim[desc.dim_numbers.getIndexVectorDim()] =
2377           index.GetConstantWithIndexType(i);
2378       llvm_ir::IrArray::Index raw_scatter_index_index(
2379           raw_scatter_index_multidim, scatter_indices_shape_fixed,
2380           index.GetType());
2381 
2382       int64_t operand_dim = desc.dim_numbers.getScatterDimsToOperandDims()[i];
2383       TF_ASSIGN_OR_RETURN(
2384           llvm::Value* const loaded_scatter_index,
2385           desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2386               scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2387       // And add the index to our window index. This yields the output index.
2388       llvm::Value* casted_scatter_index =
2389           IntCast(loaded_scatter_index, index.GetType(),
2390                   /*isSigned=*/true);
2391       llvm::Value* dim_offset =
2392           Add(input_window_multidim[operand_dim], casted_scatter_index);
2393       input_window_multidim[operand_dim] = dim_offset;
2394 
2395       // Also do the bounds check now.
2396       int64_t max_index = desc.operand_shape.dimensions(operand_dim) -
2397                           input_window_bounds[operand_dim] + 1;
2398       // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2399       //   --> index u< dim_size-window_size+1
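      // For illustration: with dim_size = 10 and window_size = 3, max_index is
      // 8, so valid scatter indices are 0..7. A negative index wraps around to
      // a huge unsigned value, so the single unsigned comparison below also
      // rejects indices below zero without a separate >= 0 check.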
2400       is_in_bounds =
2401           And(is_in_bounds, ICmpULT(casted_scatter_index,
2402                                     index.GetConstantWithIndexType(max_index)));
2403     }
2404 
2405     llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2406         is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2407     llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2408     // All done. Now read the update value for this window element and do an
2409     // atomic store to the calculated location in the output.
2410     llvm_ir::IrArray::Index input_window_index(
2411         input_window_multidim, desc.output.GetShape(), index.GetType());
2412     llvm::Value* output_address =
2413         desc.output.EmitArrayElementAddress(input_window_index, &b_);
2414     llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2415         llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2416                                        module_),
2417         "input_address", &b_);
2418     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2419                         desc.updates_gen(index));
2420     Store(input_ir_value, input_address);
2421 
2422     if (!desc.unique_indices) {
2423       return EmitAtomicOperationForNestedComputation(
2424           *desc.update_computation, output_address, input_address,
2425           desc.output.GetElementLlvmType());
2426     } else {
2427       return EmitCallToNestedComputation(*desc.update_computation,
2428                                          {output_address, input_address},
2429                                          output_address);
2430     }
2431   };
2432 
2433   // Launch a kernel that reads every element in the updates tensor. We could
2434   // also do one kernel per window instead if bounds checks turn out to be a
2435   // bottleneck.
2436   return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2437                              launch_dimensions, &b_)
2438       .EmitLoop(desc.name,
2439                 desc.get_index_type(launch_dimensions.launch_bound()));
2440 }
2441 
2442 // This transformation should be migrated off. See b/171334474.
2443 StatusOr<HloComputation*>
2444 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2445                                                        bool is_fusion) {
2446   std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2447   if (module == nullptr) {
2448     std::vector<Shape> operand_shapes, output_shapes;
2449     if (is_fusion) {
2450       mlir::Operation* clone = region->getParentOp()->clone();
2451       region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).getRegion();
2452       TF_RETURN_IF_ERROR(
2453           ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2454     }
2455 
2456     xla::XlaComputation xla_computation;
2457     mlir::MlirToHloConversionOptions options;
2458     options.propagate_layouts = true;
2459     options.propagate_bitcast_layouts_to_backend_config = true;
2460     options.legalize_node_names = false;
2461     TF_RETURN_IF_ERROR(
2462         ConvertRegionToComputation(region, &xla_computation, options));
2463 
2464     if (is_fusion) {
2465       region->getParentOp()->erase();
2466     }
2467 
2468     TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2469     TF_ASSIGN_OR_RETURN(
2470         module, HloModule::CreateFromProto(xla_computation.proto(),
2471                                            HloModuleConfig(program_shape)));
2472 
2473     if (is_fusion) {
2474       HloComputation* fused_computation = module->entry_computation();
2475 
2476       CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2477       for (int i = 0; i < fused_computation->num_parameters(); i++) {
2478         *fused_computation->parameter_instruction(i)
2479              ->mutable_shape()
2480              ->mutable_layout() = operand_shapes[i].layout();
2481       }
2482       HloInstruction* root = fused_computation->root_instruction();
2483       // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2484       // FusedIrEmitter doesn't take GTE ops because we aim to elimiate tuples
2485       // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
2486       if (root->opcode() == HloOpcode::kTuple) {
2487         [&] {
2488           HloInstruction* real_root = nullptr;
2489           int expected_tuple_index = 0;
2490           for (HloInstruction* operand : root->operands()) {
2491             if (operand->opcode() != HloOpcode::kGetTupleElement) {
2492               return;
2493             }
2494             if (real_root == nullptr) {
2495               real_root = operand->mutable_operand(0);
2496             } else if (real_root != operand->operand(0)) {
2497               return;
2498             }
2499             if (expected_tuple_index != operand->tuple_index()) {
2500               return;
2501             }
2502             expected_tuple_index++;
2503           }
2504           fused_computation->set_root_instruction(real_root);
2505           std::vector<HloInstruction*> to_be_removed;
2506           to_be_removed.push_back(root);
2507           for (HloInstruction* operand : root->operands()) {
2508             to_be_removed.push_back(operand);
2509           }
2510           for (auto instr : to_be_removed) {
2511             TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2512           }
2513 
2514           root = real_root;
2515         }();
2516       }
2517 
2518       if (output_shapes.size() > 1) {
2519         CHECK(root->shape().IsTuple());
2520         CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2521 
2522         for (int i = 0; i < output_shapes.size(); i++) {
2523           *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2524         }
2525       } else {
2526         CHECK_EQ(1, output_shapes.size());
2527         *root->mutable_shape() = output_shapes[0];
2528       }
2529     }
2530     // Post-process the generated computation:
2531     // * Sanitize constant names, so that they can be used as LLVM global
2532     // symbols.
2533     // * Propagate layouts for tuple types.
2534     for (HloComputation* computation : module->computations()) {
2535       for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2536         if (instr->opcode() == HloOpcode::kConstant) {
2537           // Notice that IR emitters use the names of constants as LLVM symbol
2538           // names, so it's important that constants in the new module do not
2539           // collide by name with constants in the original module.
2540           // Uniquify them by prepending the module name.
2541           //
2542           // TODO(timshen): A better solution would be to plumb the exact
2543           // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2544           // hard because XLA builder doesn't support setting names. Revisit
2545           // this once we get rid of this function, or don't rely on the op name
2546           // (which shouldn't be the identity) to generate LLVM symbols.
2547           instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2548               module->name() + "_" + instr->name()));
2549         }
2550       }
2551     }
2552   }
2553   return module->entry_computation();
2554 }
2555 
2556 Status IrEmitterUnnested::EmitSort(mlir::Operation* op) {
2557   auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(op);
2558   ThunkSequence thunks;
2559 
2560   std::string op_name = GetIrNameFromLoc(sort_op.getLoc());
2561   std::vector<mlir::Value> operands = GetHloOperands(sort_op);
2562   const Shape& keys_shape = GetShape(operands[0]);
2563   int64_t dimension_to_sort = sort_op.getDimension();
2564   for (int64_t i = 0; i < operands.size(); ++i) {
2565     // We assume that the layout of all involved operands and outputs is the
2566     // same.
2567     TF_RET_CHECK(
2568         LayoutUtil::LayoutsInShapesEqual(keys_shape, GetShape(operands[i])));
2569     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
2570         keys_shape, GetShape(GetHloOutputs(sort_op)[i])));
2571 
2572     // If possible, we share buffers. If that is not possible, we need to copy
2573     // the values, because the emitter does the sorting in-place.
2574     TF_ASSIGN_OR_RETURN(auto destination_buffer,
2575                         GetAllocationSlice(sort_op.getOutput()[i]));
2576     TF_ASSIGN_OR_RETURN(auto source_address,
2577                         GetAllocationSlice(sort_op.getOperands()[i]));
2578     if (destination_buffer != source_address) {
2579       // TODO(b/26783907): Figure out why we never seem to share buffers for
2580       // key/value sort.
2581       VLOG(2) << op_name << " requires initial D2D copy for operand " << i;
2582       thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2583           Thunk::ThunkInfo(),
2584           /*source_address=*/source_address,
2585           /*destination_buffer=*/destination_buffer,
2586           /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(operands[i]))));
2587     }
2588   }
2589 
2590   uint64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2591   int64_t num_stages = Log2Ceiling(dimension_to_sort_bound);
2592   VLOG(2) << op_name << " requires " << num_stages << " stages.";
2593   CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2594   CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
2595 
2596   // Naive C++ code for the outer loops:
2597   //
2598   // for (int64_t stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2599   //     ++stage) {
2600   //   int64_t first_xor_mask = (1LL << (stage + 1)) - 1;
2601   //   SortInPlace(first_xor_mask);
2602   //   for (int64_t mask = stage - 1; mask >= 0; --mask) {
2603   //     int64_t later_xor_mask = 1LL << mask;
2604   //     SortInPlace(later_xor_mask);
2605   //   }
2606   // }
2607   //
2608   // This follows the alternative representation of the algorithm described on
2609   // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2610   //
2611   // Each mask specifies how to derive from one position in the array the
2612   // position with which it should be compared (we calculate the xor of the
2613   // position with the mask).
2614   // As an optimization, we can move the 'mask' loop to inside the
2615   // sorting/comparison loop if the comparisons happen within a small block of
2616   // the array. To make this work, we collect all consecutive masks that are
2617   // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2618   // Each thread then processes one tile of data.
2619 
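  // For illustration: with num_stages = 3 the per-stage xor masks are
  //   stage 0: 1    stage 1: 3, 1    stage 2: 7, 2, 1
  // and with a hypothetical kTileSize of 4, masks smaller than 4 accumulate
  // across stages until a mask >= 4 forces a flush, giving kernels for
  // [1, 3, 1], then [7], then [2, 1]. With the real kTileSize of up to 2048,
  // far more consecutive masks are batched into a single tiled kernel.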
2620   const uint64_t kTileSize = std::min(2048ULL, 1ULL << num_stages);
2621 
2622   // If we cannot combine several xor masks together, we don't use tiling, so we
2623   // calculate the standard launch dimensions for the shape. However, we only
2624   // need to iterate through ~half of the dimension to sort (rounded up to the
2625   // next highest power of 2), because each iteration compares one pair of
2626   // elements.
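  // For example, with dimension_to_sort_bound = 1000, num_stages is 10 and the
  // standard iteration count along the sort dimension is 1 << 9 = 512.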
2627   Shape standard_iteration_shape = keys_shape;
2628   uint64_t standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2629   standard_iteration_shape.set_dimensions(dimension_to_sort,
2630                                           standard_num_iterations_in_sort_dim);
2631   TF_ASSIGN_OR_RETURN(
2632       LaunchDimensions standard_launch_dimensions,
2633       CalculateLaunchDimensions(standard_iteration_shape,
2634                                 ir_emitter_context_->gpu_device_info()));
2635 
2636   // Calculate the launch dimensions for the case where we use tiling. We split
2637   // the dimension that should be sorted into tiles of size 'kTileSize'. This
2638   // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2639   // of the tile size.
2640   int64_t rounded_bound = RoundUpTo(dimension_to_sort_bound, kTileSize);
2641   Shape iteration_shape = keys_shape;
2642 
2643   // We iterate through the element pairs that should be compared.
2644   uint64_t num_iterations_in_sort_dim = rounded_bound / 2;
2645   iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2646   uint64_t num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2647 
2648   // For correctness reasons we need exactly kTileSize / 2 threads per
2649   // block. Each thread is responsible for copying exactly two adjacent elements
2650   // into shared memory, and then does a comparison of two possibly different
2651   // elements taken from shared memory.
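  // For example, the maximum kTileSize of 2048 gives 1024 threads per block,
  // each handling one adjacent pair of the 2048 elements in its tile.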
2652   const uint64_t kThreadsPerBlock = kTileSize / 2;
2653 
2654   // Check whether we should use any tiling. We might not be able to use it if
2655   // we do not have enough threads or enough shared memory. Also, it does not
2656   // give a speedup if the tile size is < 128.
2657   int64_t total_shared_memory_needed = 0;
2658   for (int64_t i = 0; i < operands.size(); ++i) {
2659     total_shared_memory_needed +=
2660         kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2661                         GetShape(operands[i]).element_type());
2662   }
2663   bool no_tiling =
2664       kTileSize < 128 ||
2665       kThreadsPerBlock >
2666           ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2667       total_shared_memory_needed >
2668           ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2669   VLOG(2) << absl::StreamFormat(
2670       "%s %s use tiling. No tiling if any of the following is true: "
2671       "kTileSize=%d < 128, "
2672       "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2673       "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2674       op_name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2675       ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2676       total_shared_memory_needed,
2677       ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2678 
2679   uint64_t num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2680   LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2681   VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2682                                 op_name, num_blocks, kThreadsPerBlock);
2683 
2684   std::vector<llvm_ir::IrArray> ir_arrays;
2685   auto emit_kernel = [&](absl::Span<const int64_t> xor_masks) {
2686     VLOG(2) << absl::StreamFormat(
2687         "%s uses kernel for xor masks [%s]", op_name,
2688         absl::StrJoin(xor_masks, ", ", [](std::string* out, int64_t xor_mask) {
2689           absl::StrAppendFormat(out, "0x%x", xor_mask);
2690         }));
2691     thunks.emplace_back();
2692     LaunchDimensions launch_dimensions = xor_masks.size() > 1
2693                                              ? tiled_launch_dimensions
2694                                              : standard_launch_dimensions;
2695     TF_ASSIGN_OR_RETURN(
2696         thunks.back(),
2697         BuildKernelThunk(sort_op, sort_op.getOutput(), Thunk::ThunkInfo(),
2698                          &ir_arrays, launch_dimensions));
2699     std::vector<IrArray> values_arrays;
2700     values_arrays.reserve(operands.size());
2701     for (int64_t i = 0; i < operands.size(); ++i) {
2702       values_arrays.push_back(ir_arrays[i]);
2703     }
2704     TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
2705                         GetOrCreateSubComputationFromRegion(
2706                             &sort_op.getComparator(), /*is_fusion=*/false));
2707     return llvm_ir::EmitSortInPlace(
2708         dimension_to_sort, values_arrays, IrName(op_name), xor_masks, &b_,
2709         launch_dimensions,
2710         xor_masks.size() > 1 ? num_iterations_in_sort_dim
2711                              : standard_num_iterations_in_sort_dim,
2712         kTileSize,
2713         [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
2714           return EmitCallToNestedComputation(*comparator, operands, output);
2715         });
2716   };
2717   std::vector<int64_t> xor_masks;
2718   for (int64_t stage = 0; stage < num_stages; ++stage) {
2719     for (int64_t mask = stage; mask >= 0; --mask) {
2720       int64_t xor_mask;
2721       if (mask == stage) {
2722         xor_mask = (1LL << (stage + 1)) - 1;
2723       } else {
2724         xor_mask = 1LL << mask;
2725       }
2726       if (xor_mask >= kTileSize || no_tiling) {
2727         if (!xor_masks.empty()) {
2728           TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2729           xor_masks.clear();
2730         }
2731         TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
2732       } else {
2733         xor_masks.push_back(xor_mask);
2734       }
2735     }
2736   }
2737   if (!xor_masks.empty()) {
2738     TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2739   }
2740   VLOG(2) << absl::StreamFormat(
2741       "%s requires %d thunks (including any D2D copies)", op_name,
2742       thunks.size());
2743 
2744   AddThunkToThunkSequence(
2745       std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2746   return OkStatus();
2747 }
2748 
2749 template <typename ThunkType, typename OpT>
2750 Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) {
2751   auto casted = mlir::cast<OpT>(op);
2752   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2753                       GetAllocationSlice(casted.getOperand()));
2754   auto thunk = std::make_unique<ThunkType>(GetThunkInfo(op), result_slice);
2755   AddThunkToThunkSequence(std::move(thunk));
2756   return OkStatus();
2757 }
2758 
2759 Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) {
2760   auto collective_permute_op = mlir::cast<mlir::lmhlo::CollectivePermuteOp>(op);
2761 
2762   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice,
2763                       GetAllocationSlice(collective_permute_op.getOperand()));
2764   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2765                       GetAllocationSlice(collective_permute_op.getOutput()));
2766 
2767   const Shape shape = GetShape(collective_permute_op.getOperand());
2768   const int64_t replica_count = hlo_module_config_.replica_count();
2769   const int64_t partition_count = hlo_module_config_.num_partitions();
2770 
2771   if (NcclCollectivePermuteThunk::IsDegenerate(
2772           collective_permute_op, replica_count, partition_count)) {
2773     // For a degenerate collective permute, just generate a copy thunk.
2774     AddThunkToThunkSequence(std::make_unique<DeviceToDeviceCopyThunk>(
2775         GetThunkInfo(op),
2776         /*source_address=*/source_slice,
2777         /*destination_buffer=*/result_slice,
2778         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2779   } else {
2780     const NcclCollectivePermuteThunk::Buffer buffer = {
2781         /*element_count=*/ShapeUtil::ElementsIn(shape),
2782         /*source_buffer=*/source_slice,
2783         /*destination_buffer=*/result_slice};
2784     auto thunk = std::make_unique<NcclCollectivePermuteThunk>(
2785         GetThunkInfo(op), collective_permute_op, replica_count, partition_count,
2786         buffer);
2787     AddThunkToThunkSequence(std::move(thunk));
2788   }
2789   return OkStatus();
2790 }
2791 
2792 Status MaybeAddAllReduceStartThunkToMap(
2793     absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&
2794         all_reduce_start_thunks,
2795     mlir::Operation* op, Thunk* thunk) {
2796   if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
2797     TF_RET_CHECK(all_reduce_start_thunks
2798                      .emplace(op, static_cast<NcclAllReduceStartThunk*>(thunk))
2799                      .second)
2800         << "all-reduce-start with this unique ID already seen";
2801   }
2802   return OkStatus();
2803 }
2804 
2805 template <typename NcclThunkType, typename OpTy>
2806 Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) {
2807   OpTy op = mlir::cast<OpTy>(untyped_op);
2808   int64_t replica_count = hlo_module_config_.replica_count();
2809   int64_t partition_count = hlo_module_config_.num_partitions();
2810   VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
2811           << "; partition count: " << partition_count
2812           << "; operand count: " << op.getOperands().size()
2813           << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
2814 
2815   // A given collective op is degenerate if all of the groups it forms are
2816   // singletons. In such a case, we don't need to do any communication and we
2817   // can just copy the input to the output.
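  // For example, an all-gather running with replica_count = 1 and
  // partition_count = 1 forms only singleton groups, so it reduces to a
  // device-to-device copy of each operand (see the fallback path below).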
2818   bool is_degenerate =
2819       NcclThunkType::IsDegenerate(op, replica_count, partition_count);
2820   bool should_use_nccl_thunk =
2821       !is_degenerate && NcclThunkType::CanImplement(op);
2822 
2823   // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
2824   // not generate an NcclCollectiveThunk.
2825   std::vector<NcclCollectiveThunk::Buffer> buffers;
2826   buffers.reserve(op.getOperands().size());
2827   for (auto it : llvm::zip(op.getInputs(), op.getOutputs())) {
2828     mlir::Value operand = std::get<0>(it);
2829     mlir::Value result = std::get<1>(it);
2830     const Shape shape = GetShape(operand);
2831     TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
2832     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
2833     buffers.push_back(NcclCollectiveThunk::Buffer{
2834         /*element_count=*/ShapeUtil::ElementsIn(shape),
2835         /*source_buffer=*/source_slice,
2836         /*destination_buffer=*/dest_slice});
2837   }
2838 
2839   if (should_use_nccl_thunk) {
2840     auto thunk =
2841         std::make_unique<NcclThunkType>(GetThunkInfo(op), op,
2842                                         /*buffers=*/std::move(buffers));
2843     // Record thunks for all-reduce-start ops as the done ops need them.
2844     TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
2845         all_reduce_start_thunks_, op, thunk.get()));
2846     AddThunkToThunkSequence(std::move(thunk));
2847     return OkStatus();
2848   }
2849 
2850   // Signal that no all-reduce-start thunk was created by recording a nullptr.
2851   TF_RETURN_IF_ERROR(
2852       MaybeAddAllReduceStartThunkToMap(all_reduce_start_thunks_, op, nullptr));
2853 
2854   if (!is_degenerate) {
2855     CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op);
2856 
2857     std::string message = absl::StrFormat(
2858         "Requested %s not implemented on GPU; replica_count: %d; "
2859         "partition_count: %d, group_mode: %s, operand_count: %d; NCCL support: "
2860         "%d",
2861         NcclThunkType::GetName(), replica_count, partition_count,
2862         CollectiveOpGroupModeToString(group_mode), op.getOperands().size(),
2863         NcclThunkType::NcclIsEnabled());
2864     if (!op.getOperands().empty()) {
2865       const Shape shape = GetShape(op.getOperands().front());
2866       absl::StrAppendFormat(&message, "; first operand array element-type: %s",
2867                             PrimitiveType_Name(shape.element_type()));
2868     }
2869     return Unimplemented("%s", message);
2870   }
2871 
2872   VLOG(1) << "Collective call is degenerate, not doing NCCL call";
2873 
2874   // All-gather with one replica is simply the identity function. Buffer
2875   // assignment expects a copy, so that's what we do.
2876   ThunkSequence thunks;
2877   for (int64_t i = 0; i < buffers.size(); i++) {
2878     const Shape shape = GetShape(op.getOperands()[i]);
2879     thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
2880         buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
2881         /*source_address=*/buffers[i].source_buffer,
2882         /*destination_buffer=*/buffers[i].destination_buffer,
2883         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2884   }
2885   if (thunks.size() == 1) {
2886     AddThunkToThunkSequence(std::move(thunks[0]));
2887   } else {
2888     AddThunkToThunkSequence(
2889         std::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2890   }
2891   return OkStatus();
2892 }
2893 
2894 Status IrEmitterUnnested::EmitAllReduceDone(mlir::Operation* op) {
2895   auto done_op = mlir::cast<mlir::lmhlo_gpu::AllReduceDoneOp>(op);
2896   auto start_op =
2897       done_op.getToken().getDefiningOp<mlir::lmhlo_gpu::AllReduceStartOp>();
2898   auto it = all_reduce_start_thunks_.find(start_op);
2899   TF_RET_CHECK(it != all_reduce_start_thunks_.end())
2900       << "couldn't find thunk for all-reduce-start op";
2901 
2902   // Can be null if no all-reduce-start thunk was created (e.g. if the start op
2903   // is degenerate), in which case there's nothing to do here.
2904   if (it->second != nullptr) {
2905     AddThunkToThunkSequence(std::make_unique<NcclAllReduceDoneThunk>(
2906         GetThunkInfo(op), *it->second));
2907   }
2908   all_reduce_start_thunks_.erase(it);
2909   return OkStatus();
2910 }
2911 
2912 StatusOr<std::vector<ShapedSlice>> IrEmitterUnnested::GetShapedSlices(
2913     mlir::Operation::operand_range operands) {
2914   std::vector<ShapedSlice> shaped_slices;
2915   shaped_slices.reserve(operands.size());
2916   for (mlir::Value opnd : operands) {
2917     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd));
2918     shaped_slices.push_back(ShapedSlice{slice, GetShape(opnd)});
2919   }
2920   return shaped_slices;
2921 }
2922 
2923 StatusOr<std::vector<BufferAllocation::Slice>> IrEmitterUnnested::GetSlices(
2924     mlir::Operation::operand_range operands) {
2925   std::vector<BufferAllocation::Slice> slices;
2926   slices.reserve(operands.size());
2927   for (mlir::Value opnd : operands) {
2928     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd));
2929     slices.push_back(slice);
2930   }
2931   return slices;
2932 }
2933 
2934 Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) {
2935   mlir::Operation::operand_range operands =
2936       mlir::cast<mlir::lmhlo::InfeedOp>(op).getOutputs();
2937   TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands));
2938   auto thunk =
2939       std::make_unique<InfeedThunk>(GetThunkInfo(op), std::move(shaped_slices));
2940   AddThunkToThunkSequence(std::move(thunk));
2941 
2942   return OkStatus();
2943 }
2944 
2945 Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) {
2946   mlir::Operation::operand_range operands =
2947       mlir::cast<mlir::lmhlo::OutfeedOp>(op).getInputs();
2948   TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands));
2949   auto thunk = std::make_unique<OutfeedThunk>(GetThunkInfo(op),
2950                                               std::move(shaped_slices));
2951   AddThunkToThunkSequence(std::move(thunk));
2952 
2953   return OkStatus();
2954 }
2955 
2956 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunkImpl(
2957     absl::string_view name, Thunk::ThunkInfo thunk_info,
2958     absl::Span<const BufferSlice> slices,
2959     std::vector<llvm_ir::IrArray>* ir_arrays,
2960     const LaunchDimensions& launch_dimensions) {
2961   // Figure out which buffer allocations need to be passed as arguments to our
2962   // kernel.  This is simply all of the allocations referenced in slices,
2963   // plus the XLA temp buffer (if we have it).  We always include the temp
2964   // buffer because even if the kernel itself doesn't use it, a nested
2965   // subcomputation within the kernel (e.g. a kMap's computation) might.
2966   absl::flat_hash_set<const BufferAllocation*> buffers_needed;
2967   for (const auto& slice : slices) {
2968     buffers_needed.insert(slice.buffer_slice.allocation());
2969   }
2970   std::optional<const BufferAllocation*> temp_buffer;
2971   for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
2972     if (alloc.IsPreallocatedTempBuffer()) {
2973       if (!temp_buffer.has_value()) {
2974         // Retrieve the first seen temp buffer.
2975         temp_buffer = &alloc;
2976       }
2977     }
2978   }
2979   if (temp_buffer.has_value()) {
2980     buffers_needed.insert(*temp_buffer);
2981   }
2982 
2983   // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
2984   // this order.
2985   std::vector<const BufferAllocation*> non_constant_buffers;
2986   absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
2987                   [](const BufferAllocation* allocation) {
2988                     return !allocation->is_constant();
2989                   });
2990 
2991   absl::c_sort(non_constant_buffers,
2992                [](const BufferAllocation* a, const BufferAllocation* b) {
2993                  return a->index() < b->index();
2994                });
2995 
2996   llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
2997 
2998   // Build a map from a BufferAllocation to the corresponding argument in our
2999   // kernel.
3000   absl::flat_hash_map<const BufferAllocation*, llvm::Value*> kernel_args;
3001   {
3002     auto arg_it = kernel->arg_begin();
3003     auto buffers_it = non_constant_buffers.begin();
3004     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3005       kernel_args[*buffers_it] = arg_it;
3006 
3007       // Annotate all allocations with LLVM's `noalias`.
3008       // There are three kinds of allocations:
3009       // * Read-only allocations, aka input parameters that are not aliased with
3010       // outputs.
3011       // * Read-write allocations, including all output buffers, some of which
3012       // may alias with input HLO parameters, but aliased HLO buffers are always
3013       // assigned to the same allocation.
3014       // * The temp buffer.
3015       //
3016       // Read-only allocations may overlap with each other, but since they are
3017       // not mutated, they can always be annotated with `noalias` per LLVM
3018       // semantics.
3019       //
3020       // Read-write allocations and the temp buffer don't overlap with any
3021       // allocations, therefore they can also be annotated with `noalias`.
3022       kernel->addParamAttr(
3023           arg_it->getArgNo(),
3024           llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3025     }
3026   }
3027 
3028   absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3029   for (const auto& slice : slices) {
3030     if (slice.written) {
3031       buffers_written.insert(slice.buffer_slice);
3032     }
3033   }
3034 
3035   ir_arrays->clear();
3036 
3037   // For each buffer our kernel might want to touch, bind it to a value derived
3038   // from our kernel args.
3039   for (const BufferSlice& slice : slices) {
3040     const BufferAllocation::Slice& buffer_slice = slice.buffer_slice;
3041 
3042     llvm::Value* loc;
3043     if (!slice.constant_name.empty()) {
3044       loc = module_->getGlobalVariable(slice.constant_name);
3045       CHECK_NE(loc, nullptr)
3046           << "Could not find variable '" << slice.constant_name << "'";
3047     } else {
3048       CHECK(!buffer_slice.allocation()->is_constant());
3049       loc =
3050           InBoundsGEP(b_.getInt8Ty(), kernel_args.at(buffer_slice.allocation()),
3051                       {b_.getInt64(buffer_slice.offset())});
3052     }
3053 
3054     llvm::Type* ir_type = llvm_ir::ShapeToIrType(slice.shape, module_);
3055     llvm_ir::IrArray ir_array(CastToTypedValue(slice.shape, loc, &b_), ir_type,
3056                               slice.shape);
3057     if (!buffers_written.contains(slice.buffer_slice)) {
3058       ir_array.MarkInvariantOverWholeProgram(&loc->getContext());
3059     }
3060 
3061     ir_arrays->push_back(ir_array);
3062   }
3063 
3064   AnnotateThunkLaunchDimensions(launch_dimensions,
3065                                 std::string(kernel->getName()), module_);
3066 
3067   return {std::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3068                                         std::string(kernel->getName()),
3069                                         launch_dimensions)};
3070 }
3071 
3072 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk(
3073     mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3074     std::vector<llvm_ir::IrArray>* ir_arrays,
3075     const LaunchDimensions& launch_dimensions) {
3076   TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3077 
3078   std::vector<BufferSlice> slices;
3079   slices.reserve(operands.size());
3080   for (mlir::Value operand : operands) {
3081     slices.emplace_back();
3082     auto& slice = slices.back();
3083     TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3084                         GetAllocationSlice(operand, &slice.constant_name));
3085     slice.written = WritesMlirBuffer(op, operand);
3086     slice.shape = GetShape(operand);
3087   }
3088   std::string name = GetIrNameFromLoc(op->getLoc());
3089   return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3090                               launch_dimensions);
3091 }
3092 
3093 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildKernelThunk(
3094     mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3095     std::vector<llvm_ir::IrArray>* ir_arrays,
3096     const LaunchDimensions& launch_dimensions) {
3097   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3098     auto operands = GetHloOperands(op);
3099     auto outputs = GetHloOutputs(op);
3100 
3101     std::vector<BufferSlice> slices;
3102     slices.reserve(operands.size() + outputs.size());
3103     for (mlir::Value operand : operands) {
3104       slices.emplace_back();
3105       BufferSlice& slice = slices.back();
3106       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3107                           GetAllocationSlice(operand, &slice.constant_name));
3108       slice.written = false;
3109       slice.shape = GetShape(operand);
3110     }
3111     for (mlir::Value output : outputs) {
3112       slices.emplace_back();
3113       BufferSlice& slice = slices.back();
3114       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3115                           GetAllocationSlice(output, &slice.constant_name));
3116       slice.written = true;
3117       slice.shape = GetShape(output);
3118     }
3119     std::string name = GetIrNameFromLoc(op->getLoc());
3120     return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3121                                 launch_dimensions);
3122   }
3123   return BuildKernelThunk(op, op->getOperands(), thunk_info, ir_arrays,
3124                           launch_dimensions);
3125 }
3126 
3127 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3128     absl::Span<const uint8_t> init_value, const BufferAllocation::Slice& dest,
3129     const Shape& output_shape) {
3130   int64_t num_bytes = init_value.size();
3131   if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) {
3132     return std::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3133   }
3134 
3135   // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3136   // repeating the literal 4 or 2 times, so long as the destination buffer is
3137   // an even multiple of 32 bits long.
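  // For example, an 8-bit literal 0xAB becomes pattern16 = 0xABAB and
  // pattern32 = 0xABABABAB, which the Memset32BitValueThunk writes repeatedly.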
3138   if ((num_bytes == 1 || num_bytes == 2) &&
3139       ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3140     uint16_t pattern16;
3141     if (num_bytes == 1) {
3142       uint8_t b = init_value.front();
3143       pattern16 = uint16_t{b} | (uint16_t{b} << 8);
3144     } else {
3145       memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3146     }
3147     uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16);
3148     return std::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3149                                                    pattern32, dest);
3150   }
3151 
3152   // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3153   // memset so long as all 32-bit words of the scalar are equal to each other.
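  // The overlapping memcmp below checks data[i] == data[i + 4] for every i,
  // i.e. the buffer repeats with a period of 4 bytes, which is equivalent to
  // all of its 32-bit words being equal. For example, an 8-byte literal with
  // bytes 01 02 03 04 01 02 03 04 passes and is emitted as a 32-bit memset.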
3154   if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3155       memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3156           0) {
3157     uint32_t word;
3158     memcpy(&word, init_value.data(), sizeof(word));
3159     return std::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3160                                                    dest);
3161   }
3162 
3163   return nullptr;
3164 }
3165 
3166 StatusOr<std::unique_ptr<Thunk>>
3167 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3168                                                     mlir::Value dest) {
3169   mlir::DenseElementsAttr const_init;
3170   if (auto get_global_memref =
3171           mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
3172               init_value.getDefiningOp())) {
3173     auto global_memref =
3174         mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>(
3175             get_global_memref, get_global_memref.getNameAttr());
3176     if (global_memref.getConstant() && global_memref.getInitialValue()) {
3177       // If the initial value happens to be a constant, generate a specialized
3178       // thunk.
3179       const_init = global_memref.getInitialValue()
3180                        .getValue()
3181                        .cast<mlir::DenseElementsAttr>();
3182     }
3183   } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstantOp>(
3184                  init_value.getDefiningOp())) {
3185     const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3186   }
3187 
3188   if (const_init) {
3189     std::vector<uint8_t> literal_bytes;
3190     TF_RETURN_IF_ERROR(
3191         CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3192 
3193     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
3194 
3195     const Shape dest_shape = GetShape(dest);
3196     auto thunk =
3197         BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3198     if (thunk) {
3199       return {std::move(thunk)};
3200     }
3201   }
3202   return std::unique_ptr<Thunk>();
3203 }
3204 
3205 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
3206     mlir::Operation* op, mlir::Value init_value, mlir::Value dest) {
3207   // The initial value must be a scalar memref.
3208   auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3209   TF_RET_CHECK(init_type.getRank() == 0);
3210 
3211   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3212                       TryBuildConstantInitializerThunk(init_value, dest));
3213   if (constant_init_thunk) {
3214     return {std::move(constant_init_thunk)};
3215   }
3216 
3217   // Otherwise fall back to our slow initializer code. The thunk in this case
3218   // will just need the IR arrays for the initial value and the destination.
3219   const Shape dest_shape = GetShape(dest);
3220   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3221                       CalculateLaunchDimensions(
3222                           dest_shape, ir_emitter_context_->gpu_device_info()));
3223   std::vector<llvm_ir::IrArray> ir_arrays;
3224   TF_ASSIGN_OR_RETURN(
3225       std::unique_ptr<Thunk> kernel_thunk,
3226       BuildKernelThunk(op, {init_value, dest}, Thunk::ThunkInfo(), &ir_arrays,
3227                        launch_dimensions));
3228 
3229   const llvm_ir::IrArray init_array = ir_arrays[0];
3230   const llvm_ir::IrArray dest_array = ir_arrays[1];
3231 
3232   std::string name = GetIrNameFromLoc(op->getLoc());
3233   TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3234                          [=](const IrArray::Index& index) {
3235                            return init_array.EmitReadArrayElement(index, &b_);
3236                          },
3237                          {dest_array}, launch_dimensions, &b_)
3238                          .EmitLoop(GetIrNameFromLoc(op->getLoc())));
3239 
3240   return std::move(kernel_thunk);
3241 }
3242 
3243 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildFusedInitializerThunk(
3244     mlir::lmhlo::FusionOp fusion, int output_index) {
3245   auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3246       fusion.getFusionRoots()[output_index]);
3247 
3248   TF_RET_CHECK(reduce);
3249   TF_RET_CHECK(reduce.getNumResults() == 1);
3250 
3251   mlir::Value init_value = reduce.init_values()[0];
3252   mlir::Value dest = fusion.getOutputBuffers()[output_index];
3253   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3254                       TryBuildConstantInitializerThunk(init_value, dest));
3255   if (constant_init_thunk) {
3256     return {std::move(constant_init_thunk)};
3257   }
3258 
3259   auto input_buffers = fusion.getInputBuffers();
3260 
3261   const Shape dest_shape = GetShape(dest);
3262   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3263                       CalculateLaunchDimensions(
3264                           dest_shape, ir_emitter_context_->gpu_device_info()));
3265   std::vector<llvm_ir::IrArray> ir_arrays;
3266   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk,
3267                       BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
3268                                        launch_dimensions));
3269 
3270   const llvm_ir::IrArray dest_array =
3271       ir_arrays[input_buffers.size() + output_index];
3272 
3273   const HloComputation* fused_computation =
3274       *GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
3275                                            /*is_fusion=*/true);
3276 
3277   // If init_value was fused into this reduce we have to generate it first.
3278   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
3279                                           GetNestedComputer());
3280 
3281   FusedIrEmitter fused_emitter(elemental_emitter);
3282   for (int i = 0; i < fused_computation->num_parameters(); i++) {
3283     fused_emitter.BindGenerator(
3284         *fused_computation->parameter_instruction(i),
3285         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3286           return ir_arrays[i].EmitReadArrayElement(index, &b_);
3287         });
3288   }
3289   HloInstruction* instr = fused_computation->root_instruction();
3290   if (instr->opcode() != HloOpcode::kTuple) {
3291     CHECK_EQ(0, output_index);
3292   } else {
3293     instr = instr->mutable_operand(output_index);
3294   }
3295   TF_RET_CHECK(instr->shape().IsArray());
3296   TF_ASSIGN_OR_RETURN(auto generator,
3297                       fused_emitter.GetGenerator(*instr->operand(1)));
3298   TF_RETURN_IF_ERROR(
3299       ParallelLoopEmitter(generator, {dest_array}, launch_dimensions, &b_)
3300           .EmitLoop(GetIrNameFromLoc(fusion.getLoc())));
3301   return {std::move(kernel_thunk)};
3302 }
3303 
3304 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3305     mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) {
3306   // Generate thunk sequence for while 'condition'.
3307   mlir::Region* condition = &while_op.getCond();
3308   TF_ASSIGN_OR_RETURN(
3309       auto ir_emitter_condition,
3310       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3311 
3312   TF_RETURN_IF_ERROR(ir_emitter_condition->EmitLmhloRegion(condition));
3313 
3314   // Generate thunk sequence for while 'body'.
3315   mlir::Region* body = &while_op.getBody();
3316   TF_ASSIGN_OR_RETURN(
3317       auto ir_emitter_body,
3318       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3319 
3320   TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body));
3321 
3322   // Extract the condition value from the last op (excluding the terminator op)
3323   // in the condition region.
3324   auto cond_result = GetHloOutputs(while_op);
3325   TF_RET_CHECK(cond_result.size() == 1);
3326   TF_ASSIGN_OR_RETURN(auto cond_result_slice,
3327                       GetAllocationSlice(cond_result[0]));
3328 
3329   return std::unique_ptr<Thunk>(
3330       new WhileThunk(thunk_info, cond_result_slice,
3331                      ir_emitter_condition->ConsumeThunkSequence(),
3332                      ir_emitter_body->ConsumeThunkSequence()));
3333 }
3334 
3335 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3336     mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
3337     const int64_t loop_limit) {
3338   // Generate thunk sequence for while 'body' (will be used as a For loop body).
3339   TF_ASSIGN_OR_RETURN(
3340       auto ir_emitter_body,
3341       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3342   TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(&while_op.getBody()));
3343 
3344   return std::unique_ptr<Thunk>(new ForThunk(
3345       thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence()));
3346 }
3347 
3348 Status IrEmitterUnnested::EmitTargetElementLoop(
3349     const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3350   return InternalError("This should be unreachable");
3351 }
3352 
3353 // Gets the output offset as calculated from thread_id.x (to be applied to the
3354 // offset calculated from block_id and thread_id.y).
3355 static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme,
3356                                     llvm::Value* thread_id_x,
3357                                     llvm::Type* index_ty,
3358                                     llvm::IRBuilder<>* b) {
3359   int64_t multiplier = tiling_scheme.GetIndexingOrder() == kStridedIndexingX
3360                            ? tiling_scheme.GetVectorSize()
3361                            : tiling_scheme.GetTileSizeFor(kDimX);
3362   return b->CreateMul(thread_id_x,
3363                       llvm::ConstantInt::get(index_ty, multiplier));
3364 }
3365 
3366 // Emits loop through the minor (X) dimension of a tile, starting at a given
3367 // offset.
3368 //
3369 // Rough pseudocode:
3370 //
3371 // Given: offset, callback
3372 //
3373 // for (int x = 0; x < x_tile_size / vector_size; ++x) {
3374 //   for (int i = 0; i < vector_size; ++i) {
3375 //      callback(offset + x * stride * vector_size + i);
3376 //   }
3377 // }
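// For illustration, with x_tile_size = 8 and vector_size = 2 the outer loop
// runs x = 0..3 and each iteration touches the element pair at
// offset + x * stride * 2 and offset + x * stride * 2 + 1, so stride = 1
// (linear indexing) visits one contiguous run of 8 elements, while a larger
// stride (strided indexing) spreads the pairs across the tile.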
3378 static void EmitXTileLoop(
3379     const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
3380     const TilingScheme& tiling_scheme, bool check_x_tile_bounds,
3381     llvm::Value* start_offset_x, llvm::Value* y_loc,
3382     IrEmitterUnnested::ValueVector3 tile_dimensions,
3383     const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3384     const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3385   llvm::Type* index_ty = tile_dimensions[kDimX]->getType();
3386   KernelSupportLibrary ksl(b, llvm_ir::UnrollMode::kDefaultUnroll);
3387   auto constant = [&](int64_t val) {
3388     return llvm::ConstantInt::get(index_ty, val);
3389   };
3390 
3391   IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3392   int64_t vector_size = tiling_scheme.GetVectorSize();
3393   int64_t stride_x = tiling_scheme.GetIndexingOrder() == kLinearIndexingX
3394                          ? 1
3395                          : tiling_scheme.GetNumThreadsFor(kDimX);
3396   KernelSupportLibrary unrolled_ksl(b, llvm_ir::UnrollMode::kFullyUnroll);
3397   unrolled_ksl.For(
3398       "tile_loop",
3399       /*start=*/constant(0),
3400       /*end=*/constant(tiling_scheme.GetTileSizeFor(kDimX) / vector_size),
3401       /*step=*/1, [&](llvm::Value* x) {
3402         for (int64_t i = 0; i < vector_size; i++) {
3403           llvm::Value* linear_index =
3404               b->CreateAdd(b->CreateMul(x, constant(vector_size)), constant(i));
3405           llvm::Value* x_loc = b->CreateAdd(
3406               b->CreateAdd(b->CreateMul(x, constant(stride_x * vector_size)),
3407                            constant(i)),
3408               start_offset_x, "x_loc");
3409           IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3410               b->CreateAdd(b->CreateMul(x, constant(stride_x * vector_size)),
3411                            constant(i)),
3412               kDimX, b);
3413           auto emit_element = [&] {
3414             return (*emit_elem_function)(thread_id_info, source_idx_x, y_loc,
3415                                          x_loc, linear_index);
3416           };
3417           if (check_x_tile_bounds) {
3418             ksl.If("x_in_tile", b->CreateICmpULT(x_loc, tile_dimensions[kDimX]),
3419                    emit_element);
3420           } else {
3421             emit_element();
3422           }
3423         }
3424       });
3425 }
3426 
3427 void IrEmitterUnnested::EmitTile(
3428     const TilingScheme& tiling_scheme, const IrArray::Index& tile_origin_index,
3429     const ThreadIdInfo& thread_id_info, ValueVector3 tile_dimensions,
3430     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
3431   llvm::Type* index_ty = tile_dimensions[kDimY]->getType();
3432   auto constant = [&](int64_t val) {
3433     return llvm::ConstantInt::get(index_ty, val);
3434   };
3435   llvm::Value* num_threads_y = constant(tiling_scheme.GetNumThreadsFor(kDimY));
3436   llvm::Value* start_offset_x =
3437       GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, index_ty, &b_);
3438 
3439   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3440   IrArray::Index source_idx =
3441       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
3442 
3443   ksl.For(
3444       "y_in_tile",
3445       /*start=*/thread_id_info.thread_id_y,
3446       /*end=*/
3447       tile_dimensions[kDimY],
3448       /*step=*/num_threads_y, [&](llvm::Value* y_loc) {
3449         auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
3450           return EmitXTileLoop(thread_id_info, tiling_scheme,
3451                                check_x_tile_bounds, start_offset_x, y_loc,
3452                                tile_dimensions, source_idx, &b_,
3453                                &emit_elem_function);
3454         };
3455 
3456         // Only take this path when we unroll in a way vectorizable by LLVM.
3457         // Special-case the tile that doesn't fit completely, for an even row
3458         // size. For an odd row size, every other row isn't aligned to the
3459         // vectorized size, so it can't be vectorized by LLVM.
3460         if (tiling_scheme.GetIndexingOrder() == kStridedIndexingX) {
3461           ksl.If(
3462               "is_full_tile",
3463               b_.CreateICmpEQ(
3464                   constant(tiling_scheme.GetBlockTileSizeFor(kDimX)),
3465                   tile_dimensions[kDimX]),
3466               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
3467               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
3468         } else {
3469           unroll_inner_tile_loop(/*check_x_tile_bounds=*/true);
3470         }
3471       });
3472 }
3473 
3474 static IrArray::Index GetUnnormalizedIndex(
3475     const IrArray::Index& normalized_shape_index,
3476     const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
3477     const TilingScheme& tiling_scheme) {
3478   DCHECK_EQ(normalized_shape_index.size(), 3);
3479   // If the normalization only adds a new dimension of size 1,
3480   // generate simpler indexing. LLVM doesn't always simplify the more
3481   // complicated indexing, and this prevents it from vectorizing some
3482   // cases. We do this only for major_to_minor memory layout.
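  // For example, a normalized index (0, i, j) over a normalized shape
  // [1, A, B] maps straight back to (i, j) on an unnormalized 2-D shape
  // [A, B] with a matching layout, avoiding the linearization fallback.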
3483   if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3484       unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
3485       unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
3486       unnormalized_shape.layout().minor_to_major(1) == 0) {
3487     CHECK_EQ(normalized_shape_index.dims()[0], 1);
3488     auto multidim = normalized_shape_index.multidim();
3489     return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
3490                           normalized_shape_index.GetType());
3491   }
3492   if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3493       unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] &&
3494       unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] &&
3495       unnormalized_shape.layout().minor_to_major(1) == 1) {
3496     CHECK_EQ(normalized_shape_index.dims()[0], 1);
3497     auto multidim = normalized_shape_index.multidim();
3498     return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape,
3499                           normalized_shape_index.GetType());
3500   }
3501   llvm::Value* linear =
3502       normalized_shape_index.Linearize(tiling_scheme.GetDimsInElems(), b_);
3503   return IrArray::Index(linear, unnormalized_shape, b_);
3504 }
3505 
3506 // Emits code to process a tensor element in a tile for the given kLoop fusion
3507 // HLO containing parameters that are 0-2-1 transpose of its outputs.
3508 //
3509 // index: The index for the first output element in the normalized tensor, that
3510 //   is the resulting tensor after collapsing contiguous dimensions that play
3511 //   the same role in the transpose.
3512 // kernel_info: Other information to support the kernel code generation.
3513 void IrEmitterUnnested::EmitTileElementForTranspose(
3514     const ThreadIdInfo& thread_id_info, mlir::lmhlo::FusionOp fusion,
3515     absl::Span<const llvm_ir::IrArray> operand_arrays,
3516     absl::Span<const llvm_ir::IrArray> output_arrays,
3517     const llvm_ir::IrArray::Index& index, const TilingScheme& tiling_scheme,
3518     llvm::Value* y_loc, llvm::Value* x_loc,
3519     absl::Span<llvm::Value* const> param_shmem_buffers) {
3520   const HloComputation* fused_computation =
3521       *GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
3522                                            /*is_fusion=*/true);
3523   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3524                                      GetNestedComputer());
3525   FusedIrEmitter fused_emitter(elem_emitter);
3526   for (int i = 0; i < operand_arrays.size(); i++) {
3527     llvm_ir::ElementGenerator gen;
3528     if (auto* param_tile_buffer =
3529             llvm::cast_or_null<llvm::GlobalVariable>(param_shmem_buffers[i])) {
3530       gen = [this, param_tile_buffer, x_loc, y_loc,
3531              thread_id_info](llvm_ir::IrArray::Index index) {
3532         // TODO(jlebar): Add AA metadata to this load.  Tile buffers are
3533         // global variables, so LLVM's points-to analysis doesn't help us
3534         // much.  And we want the AA info to be present before address
3535         // spaces are inferred (which is pretty late in the pipeline), so
3536         // even if we had address-space-based AA in LLVM, it wouldn't help
3537         // us much here.
3538         std::vector<llvm::Value*> idx = {x_loc, y_loc};
3539         auto gep =
3540             thread_id_info.GEPIntoSharedMemory(&b_, param_tile_buffer, idx);
3541         auto type =
3542             thread_id_info.GEPIntoSharedMemoryType(param_tile_buffer, idx);
3543         return Load(type, gep, "tiled_buffer");
3544       };
3545     } else {
3546       auto array = operand_arrays[i];
3547       auto name = fused_computation->parameter_instruction(i)->name();
3548       gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
3549         return array.EmitReadArrayElement(index, &b_, name);
3550       };
3551     }
3552     fused_emitter.BindGenerator(*fused_computation->parameter_instruction(i),
3553                                 std::move(gen));
3554   }
3555   IrArray::Index untiled_index = GetUnnormalizedIndex(
3556       index, output_arrays[0].GetShape(), &b_, tiling_scheme);
3557   llvm_ir::ElementGenerator output_generator =
3558       *fused_emitter.GetGenerator(*fused_computation->root_instruction());
3559   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
3560   if (output_arrays.size() > 1) {
3561     DCHECK(output_value->getType()->isStructTy());
3562     DCHECK_EQ(output_value->getType()->getStructNumElements(),
3563               output_arrays.size());
3564     for (int64_t i = 0; i < output_arrays.size(); ++i) {
3565       output_arrays[i].EmitWriteArrayElement(
3566           untiled_index, ExtractValue(output_value, i), &b_);
3567     }
3568   } else {
3569     output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
3570   }
3571 }
3572 
3573 static int GetNumOutputs(const Shape& shape) {
3574   if (shape.IsTuple()) {
3575     return shape.tuple_shapes_size();
3576   }
3577   return 1;
3578 }
3579 
3580 ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState(
3581     mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info,
3582     absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
3583     FusedIrEmitter& fused_emitter) {
3584   ReductionCodegenState reduction_codegen_state(reduction_info);
3585   VLOG(10) << "Emit prologue for reduction: " << MlirToString(fusion);
3586 
3587   for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
3588     int num_partial_results = reduction_codegen_state.GetNumPartialResults();
3589     for (int op_result_idx = 0;
3590          op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
3591       Shape result_shape = reduce_hlo->shape().IsTuple()
3592                                ? reduce_hlo->shape().tuple_shapes(op_result_idx)
3593                                : reduce_hlo->shape();
3594 
3595       llvm::Type* element_type =
3596           llvm_ir::PrimitiveTypeToIrType(result_shape.element_type(), module_);
3597       llvm::AllocaInst* reduction_input_address =
3598           llvm_ir::EmitAllocaAtFunctionEntry(element_type,
3599                                              "reduction_input_address", &b_);
3600 
3601       llvm::AllocaInst* partial_result_address =
3602           llvm_ir::EmitAllocaAtFunctionEntryWithCount(
3603               element_type, /*element_count=*/b_.getInt32(num_partial_results),
3604               "partial_reduction_result", &b_);
3605 
3606       const HloInstruction* init_value =
3607           reduce_hlo->init_values()[op_result_idx];
3608 
3609       // Initialize the partial result with the initial value of the reduction.
3610       llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(*init_value))(
3611                                        IrArray::Index(b_.getInt32Ty()))
3612                                        .ValueOrDie();
3613 
3614       for (int i = 0; i < num_partial_results; ++i) {
3615         b_.CreateStore(init_ir_value,
3616                        InBoundsGEP(partial_result_address->getAllocatedType(),
3617                                    partial_result_address, {b_.getInt32(i)}));
3618       }
3619 
3620       const TilingScheme& tiling_scheme =
3621           reduction_codegen_state.GetTilingScheme();
3622       int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX);
3623       llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* {
3624         if (reduction_codegen_state.IsRowReduction()) {
3625           // Multi-row reductions do not use shared memory.
3626           if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) >
3627               1) {
3628             return nullptr;
3629           }
3630           // Allocate __shared__
3631           // cache[num_partial_results][num_warps][scaling_factor].
3632           CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0);
3633           int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize();
3634           return AllocateShared(tiling_scheme, element_type,
3635                                 {num_partial_results, num_warps},
3636                                 "shared_cache");
3637         } else {
3638           // Allocate __shared__
3639           // cache[num_partial_results][num_threads][num_threads + 1], where
3640           // num_threads == num_threads_x == num_threads_y.  The "+1" is used to
3641           // avoid bank conflicts.
3642           CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY));
3643           return AllocateShared(
3644               tiling_scheme, element_type,
3645               {num_partial_results, num_threads_x, num_threads_x + 1},
3646               "shared_cache");
3647         }
3648       }();
3649 
3650       llvm_ir::ElementGenerator input_gen =
3651           *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
3652       reduction_codegen_state.SetCalculationStateFor(
3653           {shared_cache, init_ir_value, partial_result_address,
3654            reduction_input_address, input_gen},
3655           reduce_hlo, op_result_idx);
3656     }
3657   }
3658 
3659   return reduction_codegen_state;
3660 }
3661 
3662 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
3663     const HloComputation* reducer,
3664     absl::Span<std::pair<llvm::Value* const, llvm::Type* const>>
3665         partial_result_addresses,
3666     int threads_per_block, int num_results_per_warp) {
3667   // This only works when the block size is a multiple of 32 threads.
3668 
3669   // We check this here because a mistake in the number of threads per
3670   // block is very hard to detect.
3671   CHECK_EQ(threads_per_block % 32, 0);
3672   CHECK_EQ(WarpSize() % num_results_per_warp, 0);
3673 
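  // Tree reduction within a warp: each step combines a lane's value with the
  // value `distance` lanes away, halving `distance` down to 1. When
  // num_results_per_warp > 1, a warp holds several independent rows, so the
  // starting distance is reduced accordingly and the reduction tree rooted at
  // each row's first lane stays within that row.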
3674   for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
3675     absl::InlinedVector<llvm::Value*, 2> reduction_params;
3676 
3677     for (auto acc : partial_result_addresses) {
3678       reduction_params.push_back(acc.first);
3679     }
3680 
3681     for (auto i : partial_result_addresses) {
3682       llvm::Value* partial_result_address = i.first;
3683       llvm::Type* element_type = i.second;
3684 
3685       int bit_width = llvm_ir::GetSizeInBits(element_type);
3686       llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
3687           element_type, "result_from_other_lane", &b_);
3688 
3689       reduction_params.push_back(result_from_other_lane);
3690 
3691       // Bitcast cannot be applied to aggregate types (even packed ones), so
3692       // we bitcast addresses of load/store to intN* of the same bit-width.
3693       llvm::Type* shuffled_value_type =
3694           element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
3695       auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
3696         return b_.CreatePointerBitCastOrAddrSpaceCast(
3697             ptr, shuffled_value_type->getPointerTo());
3698       };
3699 
3700       llvm::Value* partial_result =
3701           b_.CreateLoad(shuffled_value_type,
3702                         convert_pointer_for_shuffle(partial_result_address),
3703                         "partial_reduction_result");
3704       b_.CreateStore(
3705           EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
3706           convert_pointer_for_shuffle(result_from_other_lane));
3707     }
3708 
3709     StatusOr<std::vector<llvm::Value*>> returned_scalars =
3710         ComputeNestedElementFromAddrs(*reducer, reduction_params);
3711     TF_CHECK_OK(returned_scalars.status());
3712 
3713     for (int i = 0; i < returned_scalars->size(); i++) {
3714       b_.CreateStore(/*Val=*/returned_scalars->at(i),
3715                      /*Ptr=*/partial_result_addresses[i].first);
3716     }
3717   }
3718 }
3719 
3720 llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction(
3721     int partial_result_idx, llvm::Type* index_ty,
3722     const ReductionCodegenState& reduction_codegen_state,
3723     const TilingKernelInfo& tiling_kernel_info,
3724     const IrEmitterUnnested::ReductionOutputMap& output_arrays,
3725     const HloReduceInstruction* reduction, int output_idx) {
3726   auto constant = [&](uint64_t c) -> llvm::Constant* {
3727     return llvm::ConstantInt::get(index_ty, c);
3728   };
3729 
3730   const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3731   const ThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info;
3732 
3733   IrArray::Index start_offset = [&] {
3734     llvm::Value* x_loc = thread_id_info.thread_id_x;
3735     llvm::Value* y_loc = thread_id_info.thread_id_y;
3736     if (!reduction_codegen_state.IsRowReduction()) {
3737       std::swap(x_loc, y_loc);
3738     }
3739     llvm::Value* start_offset_x =
3740         GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_);
3741     return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
3742         .AddOffsetToDim(start_offset_x, kDimX, &b_);
3743   }();
3744 
3745   const IrArray& output_array = output_arrays.at(reduction)[output_idx];
3746   const Shape& operand_shape = reduction->inputs()[output_idx]->shape();
3747   Shape reduction_kept_element_shape =
3748       ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape);
3749 
3750   // Given the IrArray index of a reduction input, returns the linear address of
3751   // the reduction output as if the reduction were going to keep the input shape
3752   // with the dimensions being reduced moved.
3753   llvm::Value* untransposed_output_linear_address = [&] {
3754     const llvm_ir::IrArray::Index index =
3755         start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, &b_);
3756     if (reduction_codegen_state.IsRowReduction()) {
3757       // For row-reduction, y-coordinate determines which row we write into.
3758       return index[kDimY];
3759     }
3760     // For column reduction, we get the transposed address.
3761     absl::Span<const int64_t> dims_in_elem = tiling_scheme.GetDimsInElems();
3762     llvm::Value* x_dim_size =
3763         index.GetConstantWithIndexType(dims_in_elem[kDimX]);
3764     llvm::Value* x_block_offset = b_.CreateMul(index[kDimZ], x_dim_size);
3765     return b_.CreateAdd(x_block_offset, index[kDimX]);
3766   }();
3767 
3768   // A reduction is allowed to transpose its output.  For example, suppose
3769   // we are reducing the second dimension of f32[10,20,30]{3,2,1}.  We are
3770   // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
3771   // f32[10,30]{0,1} (transposing the two output dims).
3772   //
3773   // At this point in the function we have a "partial sum" of input elements
3774   // (stored in partial_result_addresses), and we need to accumulate it into
3775   // the correct output element.
3776   IrArray::Index element_index(
3777       /*linear=*/untransposed_output_linear_address,
3778       reduction_kept_element_shape, &b_);
3779   IrArray::Index output_index(element_index.multidim(), output_array.GetShape(),
3780                               element_index.GetType());
3781 
3782   return output_array.EmitArrayElementAddress(output_index, &b_,
3783                                               "output_element_address");
3784 }
3785 
3786 llvm::Value* IrEmitterUnnested::EmitBlockId(int32_t num_blocks,
3787                                             llvm::Type* index_ty) {
3788   llvm::Value* block_id = gpu::EmitCallToTargetIntrinsic(
3789       gpu::TargetIntrinsicID::kBlockIdx, {}, {}, &b_);
3790   if (num_blocks != 0) {
3791     llvm_ir::AddRangeMetadata(0, num_blocks,
3792                               llvm::cast<llvm::Instruction>(block_id));
3793   }
3794   llvm::Value* linear_block_id =
3795       b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
3796   return linear_block_id;
3797 }
3798 
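// For example, EmitPrintfWithThreadId("x=%d", {x}, /*thread_id_filter=*/0,
// /*block_id_filter=*/0) prints "[TID=%d,BID=%d] x=%d\n" only from thread 0
// of block 0 (here `x` stands for some llvm::Value* argument).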
3799 void IrEmitterUnnested::EmitPrintfWithThreadId(
3800     absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
3801     std::optional<int64_t> thread_id_filter,
3802     std::optional<int64_t> block_id_filter) {
3803   llvm::Value* thread_id = EmitThreadId(
3804       /*threads_per_block=*/1024, b_.getInt32Ty());
3805   llvm::Value* block_id = EmitBlockId(0, b_.getInt32Ty());
3806   std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
3807   updated_arguments.insert(updated_arguments.end(), arguments.begin(),
3808                            arguments.end());
3809   llvm::Value* constraint = b_.getTrue();
3810   if (thread_id_filter) {
3811     constraint = b_.CreateAnd(
3812         constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
3813   }
3814   if (block_id_filter) {
3815     constraint = b_.CreateAnd(
3816         constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
3817   }
3818   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3819   ksl.If(constraint, [&] {
3820     xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
3821                          updated_arguments, &b_);
3822   });
3823 }
3824 
3825 llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input,
3826                                                    llvm::Type* element_type,
3827                                                    llvm::Twine name) {
3828   return b_.CreateAddrSpaceCast(input,
3829                                 llvm::PointerType::get(element_type,
3830                                                        /*AddressSpace=*/0),
3831                                 name);
3832 }
3833 
3834 void IrEmitterUnnested::EmitReductionOutputForRowReduction(
3835     const TilingKernelInfo& tiling_kernel_info,
3836     const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
3837     const ReductionOutputMap& output_arrays,
3838     const HloReduceInstruction* reduction, int partial_result_idx) {
3839   const HloComputation* reducer = reduction->to_apply();
3840   const auto& thread_id_info = tiling_kernel_info.thread_id_info;
3841   auto constant = [&](uint64_t c) -> llvm::Constant* {
3842     return llvm::ConstantInt::get(index_ty, c);
3843   };
3844   auto is_zero = [&](llvm::Value* value) {
3845     return b_.CreateICmpEQ(value, constant(0));
3846   };
3847 
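  // Shape of the row-reduction epilogue emitted below:
  //  1. reduce partial results within each warp via warp shuffles;
  //  2. lane 0 of each warp writes its result to the shared-memory cache;
  //  3. after a barrier, warp 0 reduces the per-warp results and thread 0
  //     writes (or atomically accumulates) the final value.
  // Multi-row warps (num_rows_per_warp > 1) skip the shared-memory round trip.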
3848   int num_outputs = reducer->num_parameters() / 2;
3849   const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3850   absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 2>
3851       current_outputs;
3852   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
3853     const ReductionCodegenState::ReductionCalculationState& state =
3854         reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
3855     current_outputs.push_back(
3856         {InBoundsGEP(state.partial_result_address->getAllocatedType(),
3857                      state.partial_result_address,
3858                      {constant(partial_result_idx)}, "current_output"),
3859          state.partial_result_address->getAllocatedType()});
3860   }
3861 
3862   int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2];
3863   int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size);
3864   EmitFullWarpShuffleDownLoopForReduce(
3865       reducer, absl::MakeSpan(current_outputs),
3866       tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp);
3867 
3868   KernelSupportLibrary ksl(&b_);
3869   llvm::Value* warp_id =
3870       b_.CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize()));
3871 
3872   auto emit_write_output =
3873       [&](llvm::Value* write_condition,
3874           const absl::InlinedVector<
3875               std::pair<llvm::Value* const, llvm::Type* const>, 2>& values) {
3876         ksl.If("reduction_write_output", write_condition, [&] {
3877           for (int oidx = 0; oidx < num_outputs; oidx++) {
3878             llvm::Value* output_address = GetOutputAddressForReduction(
3879                 partial_result_idx, index_ty, reduction_codegen_state,
3880                 tiling_kernel_info, output_arrays, reduction, oidx);
3881 
3882             if (reduction_codegen_state.IsRaceFree()) {
3883               Store(Load(values[oidx].second, values[oidx].first, "output"),
3884                     output_address);
3885             } else {
3886               CHECK_EQ(num_outputs, 1);
3887               TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
3888                   *reducer, output_address, values[oidx].first,
3889                   values[oidx].second));
3890             }
3891           }
3892         });
3893       };
3894 
3895   if (num_rows_per_warp > 1) {
3896     llvm::Value* is_writing_thread = is_zero(b_.CreateAnd(
3897         thread_id_info.thread_id_x, constant(reduced_dimension_size - 1)));
3898     emit_write_output(is_writing_thread, current_outputs);
3899     return;
3900   }
3901 
3902   ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
3903     for (int oidx = 0; oidx < num_outputs; oidx++) {
3904       const ReductionCodegenState::ReductionCalculationState& state =
3905           reduction_codegen_state.GetCalculationStateFor(reduction, oidx);
3906       llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory(
3907           &b_, state.shared_cache, {constant(partial_result_idx), warp_id});
3908       Store(Load(current_outputs[oidx].second, current_outputs[oidx].first),
3909             shmem_output_addr);
3910     }
3911   });
3912 
3913   // TODO(cheshire): Don't we want to sync it once for everything in the
3914   // output? Not once per each?
3915   EmitSyncThreads();
3916   ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
3917     absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 2>
3918         selected_values;
3919     for (int oidx = 0; oidx < num_outputs; oidx++) {
3920       const ReductionCodegenState::ReductionCalculationState& state =
3921           reduction_codegen_state.GetCalculationStateFor(reduction, oidx);
3922       llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory(
3923           &b_, state.shared_cache,
3924           {constant(partial_result_idx), thread_id_info.lane_id});
3925 
3926       llvm::Type* element_type =
3927           state.partial_result_address->getAllocatedType();
3928 
3929       /* Ensure the initial value address is in the generic address
3929          space, not scratch. */
3930       llvm::Value* initial_value_addr =
3931           CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry(
3932                                  element_type, "initial_value_addr", &b_),
3933                              element_type);
3934       b_.CreateStore(state.initial_value, initial_value_addr);
3935 
3936       llvm::Value* warp_exists = b_.CreateICmpULT(
3937           thread_id_info.thread_id_x,
3938           constant(tiling_scheme.GetNumThreadsFor(kDimX) / WarpSize()));
3939 
3940       llvm::Value* selected_value =
3941           b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
3942 
3943       selected_values.push_back({selected_value, element_type});
3944     }
3945 
3946     // If only one warp is present in the block, then we don't need inter-warp
3947     // reduction.
3948     // TODO(b/241414088) If only one warp is present, then inter-warp
3949     // communication using shared memory and synchronization using a barrier
3950     // are also unnecessary and should be removed.
3951     if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) {
3952       EmitFullWarpShuffleDownLoopForReduce(
3953           reducer, absl::MakeSpan(selected_values),
3954           tiling_scheme.GetNumThreadsPerBlock());
3955     }
3956 
3957     emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values);
3958   });
3959 }
3960 
3961 void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
3962     const TilingKernelInfo& tiling_kernel_info,
3963     const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty,
3964     const ReductionOutputMap& output_arrays,
3965     const HloReduceInstruction* reduction, int partial_result_idx) {
3966   KernelSupportLibrary ksl(&b_);
3967   const HloComputation* reducer = reduction->to_apply();
3968   const auto& thread_id_info = tiling_kernel_info.thread_id_info;
3969 
3970   auto constant = [&](uint64_t c) -> llvm::Constant* {
3971     return llvm::ConstantInt::get(index_ty, c);
3972   };
3973   auto is_zero = [&](llvm::Value* value) {
3974     return b_.CreateICmpEQ(value, constant(0));
3975   };
3976   const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
3977   int num_outputs = reducer->num_parameters() / 2;
3978 
3979   // Store the transpose in shared memory.
3980   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
3981     const ReductionCodegenState::ReductionCalculationState& state =
3982         reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
3983     llvm::GlobalVariable* shared_cache = state.shared_cache;
3984     llvm::AddrSpaceCastInst* shmem_output_addr =
3985         llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
3986             &b_, shared_cache,
3987             {constant(partial_result_idx), thread_id_info.thread_id_x,
3988              thread_id_info.thread_id_y},
3989             "shmem_output_address"));
3990     llvm::Value* current_output =
3991         InBoundsGEP(state.partial_result_address->getAllocatedType(),
3992                     state.partial_result_address,
3993                     {constant(partial_result_idx)}, "current_output");
3994 
3995     llvm::Value* current_output_value =
3996         Load(state.partial_result_address->getAllocatedType(), current_output);
3997     b_.CreateStore(current_output_value, shmem_output_addr);
3998   }
3999 
4000   EmitSyncThreads();
4001 
4002   // Get transposed element from shared memory.
4003   absl::InlinedVector<std::pair<llvm::Value* const, llvm::Type* const>, 1>
4004       shmem_transposed_addrs;
4005 
4006   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
4007     const ReductionCodegenState::ReductionCalculationState& state =
4008         reduction_codegen_state.GetCalculationStateFor(reduction, output_idx);
4009     llvm::AddrSpaceCastInst* shmem_transposed_addr =
4010         llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
4011             &b_, state.shared_cache,
4012             {constant(partial_result_idx), thread_id_info.thread_id_y,
4013              thread_id_info.thread_id_x},
4014             "shmem_transposed_addr"));
4015     shmem_transposed_addrs.push_back(
4016         {shmem_transposed_addr, llvm::cast<llvm::GetElementPtrInst>(
4017                                     shmem_transposed_addr->getPointerOperand())
4018                                     ->getResultElementType()});
4019   }
4020 
4021   EmitFullWarpShuffleDownLoopForReduce(reducer,
4022                                        absl::MakeSpan(shmem_transposed_addrs),
4023                                        tiling_scheme.GetNumThreadsPerBlock());
4024 
4025   // Some warps in the block are completely outside of the bounds of the
4026   // tensor, so they should not write any output at all.
4027   llvm::Value* has_output = b_.CreateAnd(
4028       b_.CreateICmpULT(
4029           GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty,
4030                           &b_),
4031           tiling_kernel_info.output_tile_bounds[kDimX]),
4032       b_.CreateICmpULT(thread_id_info.thread_id_x,
4033                        tiling_kernel_info.output_tile_bounds[kDimY]));
4034 
4035   ksl.If("reduction_write_output",
4036          b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4037            for (int oidx = 0; oidx < num_outputs; oidx++) {
4038              llvm::Value* output_address = GetOutputAddressForReduction(
4039                  partial_result_idx, index_ty, reduction_codegen_state,
4040                  tiling_kernel_info, output_arrays, reduction, oidx);
4041              if (reduction_codegen_state.IsRaceFree()) {
4042                Store(Load(shmem_transposed_addrs[oidx].second,
4043                           shmem_transposed_addrs[oidx].first, "output_value"),
4044                      output_address);
4045              } else {
4046                CHECK_EQ(num_outputs, 1);
4047                TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4048                    *reducer, output_address, shmem_transposed_addrs[oidx].first,
4049                    shmem_transposed_addrs[oidx].second));
4050              }
4051            }
4052          });
4053 }
4054 
4055 llvm::Value* IrEmitterUnnested::EmitThreadId(int64_t threads_per_block,
4056                                              llvm::Type* index_ty) {
4057   // Calculate (y, x) coordinates respectively in the 2D view of thread block,
4058   // defined by (num_thread_y, num_thread_x) from thread_id.
4059   llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4060       gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4061   llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4062   return b_.CreateIntCast(thread_id_raw, index_ty,
4063                           /*isSigned=*/true, "thread.id.x");
4064 }
4065 
4066 StatusOr<IrEmitterUnnested::ThreadIdInfo> IrEmitterUnnested::EmitThreadIdInfo(
4067     const TilingScheme& tiling_scheme, llvm::Type* index_ty) {
4068   auto constant = [&](uint64_t c) -> llvm::Constant* {
4069     return llvm::ConstantInt::get(index_ty, c);
4070   };
4071   llvm::Value* thread_id_physical =
4072       EmitThreadId(tiling_scheme.GetNumThreadsPerBlockPhysical(), index_ty);
4073   int64_t num_blocks = tiling_scheme.GetNumberOfBlocksPhysical();
4074   if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
4075     return FailedPrecondition(
4076         "Number of physical blocks (%d) does not fit in an i32 in tiling "
4077         "scheme: %s",
4078         num_blocks, tiling_scheme.ToString());
4079   }
4080   llvm::Value* block_id_physical = EmitBlockId(num_blocks, index_ty);
4081 
4082   // Wait this will break coalescing.
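  // With virtual thread scaling, one physical block emulates several logical
  // blocks. For example, with a scaling factor of 2 and 256 logical threads
  // per block (so, presumably, 512 physical threads), physical threads 0..255
  // of physical block b form logical block 2*b, and physical threads 256..511
  // form logical block 2*b + 1.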
4083   llvm::Value* thread_id_logical = b_.CreateURem(
4084       thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock()));
4085   llvm::Value* scaling = b_.CreateUDiv(
4086       thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock()));
4087   llvm::Value* block_id_logical = b_.CreateAdd(
4088       b_.CreateMul(block_id_physical,
4089                    constant(tiling_scheme.GetThreadIdScalingFactor())),
4090       scaling);
4091 
4092   llvm::Value* num_threads_x_v =
4093       constant(tiling_scheme.GetNumThreadsFor(kDimX));
4094 
4095   llvm::Value* block_exists = b_.CreateICmpULT(
4096       block_id_logical, constant(tiling_scheme.GetNumberOfBlocks()));
4097   llvm_ir::EmitEarlyReturn(block_exists, &b_);
4098   return {{thread_id_logical,
4099            /*thread_id_x=*/
4100            b_.CreateURem(thread_id_logical, num_threads_x_v, "thread_id.x"),
4101            /*thread_id_y=*/
4102            b_.CreateUDiv(thread_id_logical, num_threads_x_v, "thread_id.y"),
4103            /*lane_id=*/
4104            b_.CreateURem(thread_id_logical, constant(WarpSize()), "lane_id"),
4105            /*block_id=*/block_id_logical,
4106            /*scaling=*/scaling}};
4107 }
4108 
4109 StatusOr<IrEmitterUnnested::TilingKernelInfo>
4110 IrEmitterUnnested::EmitTilingKernel(
4111     const TilingScheme& tiling_scheme, llvm::Type* index_ty,
4112     const TileElementGenerator& tile_element_generator) {
4113   absl::Span<const int64_t> dims_in_elems = tiling_scheme.GetDimsInElems();
4114   Vector3 dims_in_blocks = tiling_scheme.GetDimsInBlocks();
4115   auto constant = [&](uint64_t c) -> llvm::Constant* {
4116     return llvm::ConstantInt::get(index_ty, c);
4117   };
4118 
4119   TF_ASSIGN_OR_RETURN(ThreadIdInfo thread_id_info,
4120                       EmitThreadIdInfo(tiling_scheme, index_ty));
4121 
4122   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4123 
4124   const IrArray::Index block_coords = [&] {
4125     IrArray::Index starting_block(thread_id_info.block_id,
4126                                   ShapeUtil::MakeShapeWithDescendingLayout(
4127                                       PRED /*arbitrary*/, dims_in_blocks),
4128                                   &b_);
4129     std::vector<llvm::Value*> multidim = {
4130         b_.CreateMul(starting_block[0],
4131                      constant(tiling_scheme.GetBlockTileSizeFor(0)),
4132                      "block_origin.z"),
4133         starting_block[1], starting_block[2]};
4134     return IrArray::Index(multidim, dims_in_blocks, index_ty);
4135   }();
4136 
4137   ValueVector3 tile_dimensions;
4138   for (int i = kDimY; i < kDimTot; ++i) {
4139     int64_t tile_size_for_dim = tiling_scheme.GetBlockTileSizeFor(i);
4140     // Only the last row or column may not have full size.
4141     llvm::Value* is_last =
4142         b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4143     int64_t partial_row =
4144         dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4145     tile_dimensions[i] =
4146         b_.CreateSelect(is_last, constant(partial_row),
4147                         constant(tile_size_for_dim), "tile_bound");
4148   }
4149 
4150   IrArray::Index tile_origin = [&] {
4151     std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4152     llvm::Type* index_ty = block_coords.GetType();
4153     for (int i = kDimY; i < kDimTot; ++i) {
4154       elem_multi_index[i] =
4155           b_.CreateMul(block_coords[i],
4156                        llvm::ConstantInt::get(
4157                            index_ty, tiling_scheme.GetBlockTileSizeFor(i)),
4158                        "tile_origin." + std::to_string(i));
4159     }
4160     return IrArray::Index(elem_multi_index, tiling_scheme.GetDimsInElems(),
4161                           index_ty);
4162   }();
4163 
4164   auto emit_tile = [&](const IrArray::Index& tile) {
4165     tile_element_generator(thread_id_info, tile, tile_dimensions);
4166   };
4167 
4168   if (tiling_scheme.GetBlockTileSizeFor(kDimZ) == 1) {
4169     emit_tile(tile_origin);
4170   } else {
4171     llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4172     llvm::Value* block_size_for_dim =
4173         constant(tiling_scheme.GetBlockTileSizeFor(kDimZ));
4174     llvm::Value* block_id_for_dim =
4175         b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4176     llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4177     llvm::Value* last_block_size_for_dim = constant(
4178         dims_in_elems[kDimZ] -
4179         (dims_in_blocks[kDimZ] - 1) * tiling_scheme.GetBlockTileSizeFor(kDimZ));
4180 
4181     llvm::Value* num_tiles_in_block =
4182         b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4183                         last_block_size_for_dim, block_size_for_dim);
4184     ksl.For("loop_z",
4185             /*start=*/constant(0),
4186             /*end=*/num_tiles_in_block,
4187             /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4188               IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4189                   block_dim_induction_var, kDimZ, &b_);
4190               emit_tile(tile_index);
4191             });
4192   }
4193 
4194   return {{tile_dimensions, tile_origin, thread_id_info}};
4195 }
4196 
4197 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4198   return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4199 }
4200 
4201 Status IrEmitterUnnested::EmitHlo021Tile(
4202     mlir::lmhlo::FusionOp fusion,
4203     absl::Span<const llvm_ir::IrArray> operand_arrays,
4204     absl::Span<const llvm_ir::IrArray> output_arrays,
4205     TransposeDimsAndParams descr, const TilingScheme& tiling_scheme,
4206     const LaunchDimensions& launch_dimensions) {
4207   std::string name = GetIrNameFromLoc(fusion.getLoc());
4208 
4209   llvm::Type* index_type = GetIndexTypeForKernel(
4210       fusion.getOperation(), launch_dimensions.launch_bound(), &b_);
4211 
4212   // For each tiled parameter, cast its input IrArray to the corresponding
4213   // reduced shape and keep the reduced shape live during IR emission.
4214   std::map<int64_t, IrArray> param_in_reduced_shape_arrays;
4215   std::vector<llvm::Value*> param_shmem_buffers(fusion.getInputBuffers().size(),
4216                                                 nullptr);
4217 
4218   auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
4219                                       absl::string_view buffer_name) {
4220     // For Nvidia GPUs, the warp size is 32 threads and shared memory is
4221     // organized into 32 banks. We usually use the warp size, or a multiple
4222     // of the warp size, as the tile size. This may cause all elements in the
4223     // same column of a tile to use the same memory bank and therefore cause
4224     // shared memory bank conflicts. Adding 1 to the minor dimension of the
4225     // shared memory buffer can reduce such shared memory bank conflicts.
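    // For instance, a 32x32 f32 tile padded to 32x33 shifts each row by one
    // bank, so a column access touches 32 different banks instead of one.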
4226     return AllocateShared(tiling_scheme, elem_ty,
4227                           {tiling_scheme.GetBlockTileSizeFor(kDimY),
4228                            tiling_scheme.GetBlockTileSizeFor(kDimX) + 1},
4229                           buffer_name);
4230   };
4231 
4232   for (int64_t id : descr.params) {
4233     const Shape& param_shape = GetShape(fusion.getInputBuffers()[id]);
4234     param_shmem_buffers[id] = get_shared_memory_buffer(
4235         llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
4236         IrName(name, StrCat("tile", id)));
4237     VLOG(3) << "Added shmem buffer for parameter " << id << ": "
4238             << llvm_ir::DumpToString(*param_shmem_buffers[id]);
4239     Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4240         param_shape.element_type(), Permute(descr.dims, {0, 2, 1}));
4241     param_in_reduced_shape_arrays[id] =
4242         operand_arrays[id].CastToShape(reduced_shape, &b_);
4243   }
4244 
4245   TileElementGenerator tile_generator =
4246       [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4247           ValueVector3 tile_dimensions) {
4248         // If shared memory transpose is needed, wait for all threads to reach
4249         // this point, lest we copy a value from tile to output before the other
4250         // thread copies it from input to tile. This is `__syncthreads` in CUDA.
4251         if (!descr.params.empty()) {
4252           // Calculate the input tile origin from the output tile origin.
4253           const IrArray::Index input_tile_origin(
4254               Permute(index.multidim(), {0, 2, 1}),
4255               Permute(index.dims(), {0, 2, 1}), index.GetType());
4256 
4257           ValueVector3 transposed_tile_dimensions = {tile_dimensions[kDimZ],
4258                                                      tile_dimensions[kDimX],
4259                                                      tile_dimensions[kDimY]};
4260 
4261           // Copy input parameter values to shared memory buffers:
4262           // tile[thread_id_y, thread_id_x] = input[index]
4263           // Note that tile_width and tile_height are flipped here because we
4264           // are reading a transposed tile.
4265           EmitTile(tiling_scheme, input_tile_origin, thread_id_info,
4266                    transposed_tile_dimensions,
4267                    [&](const ThreadIdInfo& thread_id_info,
4268                        const IrArray::Index& index, llvm::Value* y_loc,
4269                        llvm::Value* x_loc, llvm::Value* /*x_iter_num*/) {
4270                      for (int64_t id : descr.params) {
4271                        IrArray& input_in_logical_shape =
4272                            param_in_reduced_shape_arrays[id];
4273 
4274                        auto shmem_buffer = llvm::cast<llvm::GlobalVariable>(
4275                            param_shmem_buffers.at(id));
4276                        llvm::Value* value =
4277                            input_in_logical_shape.EmitReadArrayElement(
4278                                index, &b_, "input_element");
4279 
4280                        llvm::Value* addr = thread_id_info.GEPIntoSharedMemory(
4281                            &b_, shmem_buffer, {y_loc, x_loc});
4282                        b_.CreateStore(value, addr);
4283                      }
4284                    });
4285 
4286           // Wait for all threads to reach this point using `__syncthreads` in
4287           // CUDA.
4288           EmitSyncThreads();
4289         }
4290 
4291         EmitTile(tiling_scheme, index, thread_id_info, tile_dimensions,
4292                  /*emit_elem_function=*/
4293                  [&](const ThreadIdInfo& thread_id_info,
4294                      const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4295                      llvm::Value* x_loc, llvm::Value* /*x_iter_num*/) {
4296                    EmitTileElementForTranspose(
4297                        thread_id_info, fusion, operand_arrays, output_arrays,
4298                        index, tiling_scheme, y_loc, x_loc, param_shmem_buffers);
4299                  });
4300         bool block_contains_multi_tiles =
4301             tiling_scheme.GetBlockTileSizeFor(kDimZ) > 1;
4302 
4303         // If a tile block contains multiple tiles and shared memory buffers are
4304         // used, we need to wait for all threads to finish using the shared
4305         // memory buffer for the current tile before we move on to process the
4306         // next tile and overwrite the shared memory buffers.
4307         if (block_contains_multi_tiles && !descr.params.empty()) {
4308           EmitSyncThreads();
4309         }
4310       };
4311 
4312   return EmitTilingKernel(tiling_scheme, index_type, tile_generator).status();
4313 }
4314 
4315 Status IrEmitterUnnested::Emit021Transpose(TransposeDimsAndParams descr,
4316                                            mlir::lmhlo::FusionOp fusion) {
4317   constexpr int kNumRows = 4;
4318   CHECK_EQ(WarpSize() % kNumRows, 0);
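  // Assuming the block tile size per dimension is tile_size * num_threads,
  // the scheme below produces 32x32 tiles handled by 4x32 = 128 threads, with
  // each thread covering WarpSize() / kNumRows = 8 rows of its tile.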
4319   TilingScheme tiling_scheme(descr.dims,
4320                              /*tile_sizes=*/{1, WarpSize() / kNumRows, 1},
4321                              /*num_threads=*/{1, kNumRows, WarpSize()},
4322                              /*indexing_order=*/kLinearIndexingX,
4323                              /*vector_size=*/1,
4324                              /*scaling_factor=*/1);
4325   LaunchDimensions launch_dimensions(
4326       tiling_scheme.GetNumberOfBlocksPhysical(),
4327       tiling_scheme.GetNumThreadsPerBlockPhysical());
4328   std::vector<llvm_ir::IrArray> ir_arrays;
4329   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
4330                       BuildKernelThunk(fusion, GetThunkInfo(fusion), &ir_arrays,
4331                                        launch_dimensions));
4332   TF_RETURN_IF_ERROR(EmitHlo021Tile(
4333       fusion,
4334       absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size()),
4335       absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size()), descr,
4336       tiling_scheme, launch_dimensions));
4337   AddThunkToThunkSequence(std::move(kernel_thunk));
4338   return Status::OK();
4339 }
4340 
4341 namespace {
4342 
4343 // Returns true if all the transitive users of hlo before hitting users in
4344 // use_chain_endings are elementwise operations.
4345 bool AreUsersElementwise(
4346     mlir::Value value,
4347     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4348   return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
4349     mlir::Operation* user = use.getOwner();
4350     return use_chain_endings.count(user) ||
4351            (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
4352             absl::c_all_of(user->getResults(),
4353                            [&](const mlir::OpResult result) {
4354                              return AreUsersElementwise(result,
4355                                                         use_chain_endings);
4356                            })
4357 
4358            );
4359   });
4360 }
4361 
4362 // Returns the number of fusion inputs that have the same dimensions as the
4363 // given shape and are involved only in elementwise operations.
4364 int64_t NumInputsInvolveInOnlyElementwiseOps(
4365     mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
4366     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4367   return absl::c_count_if(
4368       fusion.getFusionParameters(), [&](mlir::Value parameter) {
4369         Shape parameter_shape = GetShape(parameter);
4370         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
4371                AreUsersElementwise(parameter, use_chain_endings);
4372       });
4373 }
4374 
4375 // Returns the number of fusion inputs that have more elements than the given
4376 // shape.
4377 int64_t NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
4378                                       const Shape& shape) {
4379   int64_t num_elements = ShapeUtil::ElementsIn(shape);
4380   return absl::c_count_if(
4381       fusion.getFusionParameters(), [&](mlir::Value parameter) {
4382         Shape parameter_shape = GetShape(parameter);
4383         return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
4384       });
4385 }
4386 
4387 // The benefit of unrolling a kInput fusion that is a column reduction comes
4388 // from the vectorization of non-reduction fusion outputs and fusion inputs.
4389 // On the other hand, unrolling can also introduce factors that can cause
4390 // the kernel to run slower. This routine uses a simple heuristic to estimate
4391 // the benefit as well as the overhead of unrolling in order to decide whether
4392 // unrolling is beneficial for the given kInput fusion.
4393 bool IsUnrollingColumnReductionBeneficial(mlir::lmhlo::FusionOp fusion,
4394                                           const Shape& input_shape,
4395                                           int64_t num_kept_minor) {
4396   if (num_kept_minor % (WarpSize() * 2) != 0) {
4397     return false;
4398   }
4399 
4400   if (input_shape.dimensions()[input_shape.rank() - 1] < 64) {
4401     return false;
4402   }
4403 
4404   int64_t can_be_vectorized = 0;
4405   int64_t cannot_be_vectorized = 0;
4406   llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots();
4407   absl::flat_hash_set<mlir::Operation*> use_chain_endings;
4408   if (fusion_roots.size() == 1) {
4409     if (IsReductionFromOrToContiguousDimensions(fusion_roots[0])) {
4410       use_chain_endings.insert(fusion_roots[0]);
4411       // Atomic.add of the reduction result can't be vectorized.
4412       cannot_be_vectorized++;
4413     }
4414   } else {
4415     for (mlir::Operation* op : fusion_roots) {
4416       if (IsReductionFromOrToContiguousDimensions(op)) {
4417         // Atomic.add of the reduction result can't be vectorized.
4418         cannot_be_vectorized++;
4419       } else {
4420         // Write of the non-reduction result can be vectorized.
4421         can_be_vectorized++;
4422       }
4423       use_chain_endings.insert(op);
4424     }
4425   }
4426   // Fusion inputs that have the same dimensions as the reduce input and
4427   // are involved only in elementwise operations can be vectorized.
4428   can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
4429                                                             use_chain_endings);
4430   // Fusion inputs with more elements than the reduce op input must participate
4431   // in non-elementwise operations and we assume that they are not vectorizable
4432   // for the purpose of estimating the benefit of unrolling. If the kernel is
4433   // unrolled even with such an assumption,  and the accesses to those inputs
4434   // turn out to be vectorizable, the compiler will still vectorize them.
4435   cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
4436   return can_be_vectorized >= cannot_be_vectorized;
4437 }
4438 
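// Rounds v to the nearest power of two, e.g. NearestPowerOfTwo(5) == 4 and
// NearestPowerOfTwo(7) == 8; exact midpoints such as 6 round down.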
4439 int64_t NearestPowerOfTwo(int64_t v) {
4440   if (v < 0) {
4441     return 0;
4442   }
4443   int64_t upper = absl::bit_ceil<uint64_t>(v);
4444   int64_t lower = upper >> 1;
4445   return upper - v < v - lower ? upper : lower;
4446 }
4447 
4448 }  // namespace
4449 
4450 // Returns primitive bitwidth for shape of the value.
4451 static int GetPrimitiveBitwidth(mlir::Value i) {
4452   // TODO(timshen): may not be efficient.
4453   return primitive_util::BitWidth(GetShape(i).element_type());
4454 }
4455 
4456 // Experimentally determined values to achieve an optimal number of
4457 // bytes-in-flight. With a bound on the number of warps per SM that can be
4458 // concurrently scheduled, it can be hard to achieve an optimal number of
4459 // bytes-in-flight for small reduced values. To address this, we increase the
4460 // number of threads per block (physically, while keeping the logical mapping
4461 // the same), which allows a larger number of bytes-in-flight.
4462 static int CalculateVirtualThreadScalingFactorForReduction(
4463     const ReductionDimensions& reduction_dimensions,
4464     const se::CudaComputeCapability& cc) {
4465   int dimx = reduction_dimensions.dimensions[kDimX];
4466   if (reduction_dimensions.is_row_reduction && dimx <= 128) {
4467     int rows_per_warp = RowReductionGetRowsPerWarp(dimx);
4468     if (cc.IsAtLeast(se::CudaComputeCapability::AMPERE)) {
4469       return rows_per_warp * 3;
4470     }
4471     return rows_per_warp * 5;
4472   }
4473   return 1;
4474 }
4475 
4476 llvm::Type* IrEmitterUnnested::ThreadIdInfo::GEPIntoSharedMemoryType(
4477     llvm::GlobalVariable* shared,
4478     absl::Span<llvm::Value* const> idx_major_to_minor) const {
4479   std::vector<llvm::Value*> idxs_scaled;
4480   idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0));
4481   idxs_scaled.push_back(scaling);
4482   idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(),
4483                      idx_major_to_minor.end());
4484   return llvm::GetElementPtrInst::getIndexedType(shared->getValueType(),
4485                                                  idxs_scaled);
4486 }
4487 
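// The buffer produced by AllocateShared is declared as
// [scaling_factor][dims_major_to_minor...]; prefixing the caller's indices
// with {0, scaling} below gives each virtual-thread group its own slice.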
4488 llvm::Value* IrEmitterUnnested::ThreadIdInfo::GEPIntoSharedMemory(
4489     llvm::IRBuilder<>* b, llvm::GlobalVariable* shared,
4490     absl::Span<llvm::Value* const> idx_major_to_minor,
4491     const llvm::Twine& name) const {
4492   std::vector<llvm::Value*> idxs_scaled;
4493   idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0));
4494   idxs_scaled.push_back(scaling);
4495   idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(),
4496                      idx_major_to_minor.end());
4497   llvm::Value* gep =
4498       b->CreateInBoundsGEP(shared->getValueType(), shared, idxs_scaled, name);
4499 
4500   llvm::PointerType* pointer_in_addressspace =
4501       llvm::PointerType::getWithSamePointeeType(
4502           llvm::cast<llvm::PointerType>(gep->getType()), /*AddressSpace=*/0);
4503 
4504   // __shared__ memory uses a different address space, so we cast it to
4505   // global address space before writing or reading.
4506   return b->CreateAddrSpaceCast(gep, pointer_in_addressspace);
4507 }
4508 
4509 llvm::GlobalVariable* IrEmitterUnnested::AllocateShared(
4510     const TilingScheme& tiling_scheme, llvm::Type* element_type,
4511     absl::Span<int64_t const> dimensions_major_to_minor,
4512     absl::string_view buffer_name) {
4513   CHECK(!dimensions_major_to_minor.empty());
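  // Build the nested array type minor-to-major: dims {a, b, c} become
  // [scaling_factor x [a x [b x [c x element_type]]]].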
4514   llvm::Type* array_type = nullptr;
4515   for (int i = dimensions_major_to_minor.size() - 1; i >= 0; i--) {
4516     // Iterate in minor-to-major order.
4517     int64_t dim = dimensions_major_to_minor[i];
4518     if (!array_type) {
4519       array_type = llvm::ArrayType::get(element_type, dim);
4520     } else {
4521       array_type = llvm::ArrayType::get(array_type, dim);
4522     }
4523   }
4524   array_type = llvm::ArrayType::get(array_type,
4525                                     tiling_scheme.GetThreadIdScalingFactor());
4526   return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4527                                            array_type, buffer_name);
4528 }
4529 
4530 // Whether the reduction can be vectorized.
4531 static bool CanVectorizeReduction(
4532     se::CudaComputeCapability cc, mlir::lmhlo::FusionOp fusion,
4533     const ReductionDimensions& reduction_dimensions, int num_threads_x,
4534     std::array<int64_t, 3> reduction_tiling, const Shape& input_shape) {
4535   if (!reduction_dimensions.is_row_reduction) {
4536     return IsUnrollingColumnReductionBeneficial(
4537         fusion, input_shape, reduction_dimensions.dimensions[kDimX]);
4538   }
4539 
4540   if (reduction_dimensions.dimensions[kDimX] % 2 != 0 ||
4541       MayPreventVectorization(fusion)) {
4542     return false;
4543   }
4544 
4545   // Enabling vectorization when the number of threads is <= the warp size
4546   // leads to half or more of the threads not doing any work.
4547   if (reduction_dimensions.is_row_reduction && num_threads_x <= WarpSize()) {
4548     return false;
4549   }
4550 
4551   if (cc.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
4552     return true;
4553   }
4554 
4555   int smallest_input_dtype_bits = std::numeric_limits<int>::max();
4556   for (mlir::Value operand : fusion.getInputBuffers()) {
4557     smallest_input_dtype_bits =
4558         std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
4559   }
4560   if (cc.IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
4561     return smallest_input_dtype_bits <= 32 &&
4562            reduction_dimensions.dimensions[kDimX] %
4563                    (reduction_tiling[2] * num_threads_x) ==
4564                0;
4565   }
4566   return false;
4567 }
4568 
4569 StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
4570     mlir::lmhlo::FusionOp fusion, mlir::mhlo::ReduceOp first_reduce) {
4571   Shape input_shape = GetShape(first_reduce->getOperand(0));
4572   ReductionDimensions reduction_dimensions =
4573       GetReductionKindAndContiguousComponents(first_reduce);
4574   VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
4575            << " " << reduction_dimensions.dimensions[0] << " "
4576            << reduction_dimensions.dimensions[1] << " "
4577            << reduction_dimensions.dimensions[2];
4578   Vector3 reduction_tiling = GetReductionTiling(
4579       reduction_dimensions, ir_emitter_context_->cuda_compute_capability());
4580 
4581   int64_t num_threads_y =
4582       reduction_dimensions.is_row_reduction ? 1 : WarpSize();
4583   int64_t num_threads_x = [&] {
4584     if (reduction_dimensions.is_row_reduction) {
4585       if (RowReductionGetRowsPerWarp(reduction_dimensions.dimensions[2]) > 1) {
4586         return reduction_dimensions.dimensions[2];
4587       }
4588       // Use 512 as default block size (threads per block) for row reductions.
4589       // For multi-output fusions, reduce the block size further to decrease
4590       // register pressure when multiple outputs are computed by each thread.
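      // E.g. (illustrative): with 2 fusion outputs the cap becomes
      // max(MinThreadsXRowReduction(), 256); with 8 outputs,
      // max(MinThreadsXRowReduction(), 64).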
4591       int64_t fan_out = fusion.getFusionRoots().size();
4592       int64_t max_block_size =
4593           std::max(MinThreadsXRowReduction(),
4594                    static_cast<int64_t>(512LL / NearestPowerOfTwo(fan_out)));
4595       return std::min(max_block_size,
4596                       RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2],
4597                                             reduction_tiling[2]),
4598                                 WarpSize()));
4599     }
4600     return WarpSize();
4601   }();
4602 
4603   se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
4604 
4605   int smallest_input_dtype_bits = std::numeric_limits<int>::max();
4606   for (mlir::Value operand : fusion.getInputBuffers()) {
4607     smallest_input_dtype_bits =
4608         std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
4609   }
4610 
4611   TilingScheme::IndexingOrder indexing_order =
4612       reduction_dimensions.is_row_reduction ? kStridedIndexingX
4613                                             : kLinearIndexingX;
4614   bool vectorize =
4615       CanVectorizeReduction(cc, fusion, reduction_dimensions, num_threads_x,
4616                             reduction_tiling, input_shape);
4617   int vector_size = vectorize ? 2 : 1;
4618   int num_partial_results =
4619       !reduction_dimensions.is_row_reduction && vectorize ? 2 : 1;
4620   VLOG(3) << "Each thread will produce " << num_partial_results
4621           << " output(s)";
4622   reduction_tiling[kDimX] *= num_partial_results;
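  // (Descriptive) A vectorized column reduction computes two partial results
  // per thread, so the X tile is widened to cover both.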
4623 
4624   Vector3 num_threads = {1, num_threads_y, num_threads_x};
4625   int virtual_thread_scaling_factor =
4626       CalculateVirtualThreadScalingFactorForReduction(reduction_dimensions, cc);
4627   VLOG(2) << "Using virtual thread scaling: " << virtual_thread_scaling_factor;
4628 
4629   TilingScheme tiling_scheme(reduction_dimensions.dimensions, reduction_tiling,
4630                              num_threads, indexing_order, vector_size,
4631                              virtual_thread_scaling_factor);
4632   return ReductionCodegenInfo(
4633       tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction,
4634       ReductionIsRaceFree(reduction_dimensions, reduction_tiling));
4635 }
4636 
4637 // Generate a single element of the tile (update the accumulator state) for a
4638 // given reducer of index `i`.
4639 void IrEmitterUnnested::GenerateElementForReducer(
4640     const HloReduceInstruction* reduction, llvm::Value* partial_result_index,
4641     const ReductionCodegenState& codegen_state,
4642     const llvm_ir::IrArray::Index& index_without_linear,
4643     const IrArray::Index& input_index, int num_partial_results,
4644     const ReductionOutputMap& result_ir_arrays) {
4645   HloComputation* reducer = reduction->to_apply();
4646   CHECK_EQ(reducer->num_parameters() % 2, 0);
4647 
4648   absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
4649   absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
4650   for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
4651     const ReductionCodegenState::ReductionCalculationState& state =
4652         codegen_state.GetCalculationStateFor(reduction, red_idx);
4653 
4654     llvm::AllocaInst* input_address = state.input_address;
4655     llvm::AllocaInst* partial_reduction_result_address =
4656         state.partial_result_address;
4657     llvm::Value* const input_ir_value = *state.input_gen(
4658         num_partial_results > 1 ? index_without_linear : input_index);
4659     b_.CreateStore(input_ir_value, input_address);
4660     llvm::Value* partial_result_address =
4661         InBoundsGEP(partial_reduction_result_address->getAllocatedType(),
4662                     partial_reduction_result_address, {partial_result_index});
4663     reduction_accumulators.push_back(partial_result_address);
4664     reduction_input_value.push_back(input_address);
4665   }
4666 
4667   absl::InlinedVector<llvm::Value*, 4> reduction_params;
4668   for (llvm::Value* acc : reduction_accumulators) {
4669     reduction_params.push_back(acc);
4670   }
4671   for (llvm::Value* value : reduction_input_value) {
4672     reduction_params.push_back(value);
4673   }
4674 
4675   // Emit a call to the variadic reducer. Since it may be returning a
4676   // tuple, we can't return it directly as a value. Instead, before
4677   // the call, we create N (N = # arguments in the tuple) allocas, one
4678   // for each returned argument, then when we make the call we pass N
4679   // pointers as last parameters, the called computation writes into
4680   // those pointers, and we have returned values on the stack (as well
4681   // as pointers to them).
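  // Sketch with a hypothetical two-value reducer (e.g. arg-max over
  // (value, index) pairs): reduction_params is {acc_value_ptr, acc_index_ptr,
  // cur_value_ptr, cur_index_ptr}; the nested call returns the two reduced
  // scalars, which are stored back into the accumulators below.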
4682   StatusOr<std::vector<llvm::Value*>> returned_scalars =
4683       ComputeNestedElementFromAddrs(*reducer, reduction_params);
4684   TF_CHECK_OK(returned_scalars.status());
4685 
4686   for (int i = 0; i < returned_scalars->size(); i++) {
4687     b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
4688   }
4689 }
4690 
4691 Status IrEmitterUnnested::EmitIRForReduction(
4692     mlir::lmhlo::FusionOp fusion,
4693     absl::Span<HloInstruction* const> instr_index_group,
4694     FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
4695     const ReductionCodegenInfo& reduction_info, const Shape& input_shape) {
4696   std::vector<const HloReduceInstruction*> reductions;
4697   ExtraOutputGensMap extra_output_gens;
4698   for (const HloInstruction* hlo : instr_index_group) {
4699     if (IsReductionFromOrToContiguousDimensions(*hlo)) {
4700       reductions.push_back(Cast<HloReduceInstruction>(hlo));
4701     } else {
4702       extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
4703     }
4704   }
4705 
4706   CHECK(!reductions.empty()) << " expect at least one reduce instruction.";
4707   const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme();
4708   CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0);
4709   llvm::Type* index_ty =
4710       GetIndexTypeForKernel(fusion,
4711                             tiling_scheme.GetNumThreadsPerBlockPhysical() *
4712                                 tiling_scheme.GetNumberOfBlocksPhysical(),
4713                             &b_);
4714   ReductionCodegenState codegen_state = GenerateReductionCodegenState(
4715       fusion, reduction_info, reductions, fused_emitter);
4716 
4717   EmitElementFunction emit_reduction_element =
4718       [&](const ThreadIdInfo& thread_id_info,
4719           const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4720           llvm::Value* x_loc, llvm::Value* x_iter_num) {
4721         IrArray::Index input_index = GetUnnormalizedIndex(
4722             index, input_shape, &b_, codegen_state.GetTilingScheme());
4723 
4724         llvm::Value* partial_result_index =
4725             codegen_state.IsRowReduction() ? b_.getInt32(0) : x_iter_num;
4726 
4727         // Clear the linear index field of the IrArray::Index to enable the use
4728         // of GetElementPointer with array types. This enables the vectorization
4729         // of the computation for different partial results. Use this index if
4730         // 'num_partial_results > 1'.
4731         int num_partial_results = codegen_state.GetNumPartialResults();
4732         llvm_ir::IrArray::Index index_without_linear = IrArray::Index(
4733             input_index.multidim(), input_shape, input_index.GetType());
4734 
4735         // Emit code to generate the input and perform the reduction computation
4736         // for each reduction instruction.
4737         for (const HloReduceInstruction* reduce : reductions) {
4738           GenerateElementForReducer(reduce, partial_result_index, codegen_state,
4739                                     index_without_linear, input_index,
4740                                     num_partial_results, result_ir_arrays);
4741         }
4742 
4743         // Emit code to generate the output for the non-reduction instructions
4744         // in the fusion, if any.
4745         TF_CHECK_OK(EmitExtraOutputsForReduce(
4746             result_ir_arrays, input_index, reduction_info, extra_output_gens));
4747       };
4748 
4749   TF_ASSIGN_OR_RETURN(
4750       TilingKernelInfo tiling_kernel_info,
4751       EmitTilingKernel(
4752           tiling_scheme, index_ty,
4753           [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4754               ValueVector3 tile_dimensions) {
4755             EmitTile(codegen_state.GetTilingScheme(), index, thread_id_info,
4756                      tile_dimensions, emit_reduction_element);
4757           }));
4758 
4759   KernelSupportLibrary ksl(&b_);
4760   for (const HloReduceInstruction* reduce : reductions) {
4761     for (int partial_result_idx = 0;
4762          partial_result_idx < reduction_info.GetNumPartialResults();
4763          ++partial_result_idx) {
4764       if (codegen_state.IsRowReduction()) {
4765         EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state,
4766                                            index_ty, result_ir_arrays, reduce,
4767                                            partial_result_idx);
4768       } else {
4769         EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state,
4770                                               index_ty, result_ir_arrays,
4771                                               reduce, partial_result_idx);
4772       }
4773     }
4774   }
4775 
4776   return OkStatus();
4777 }
4778 
4779 namespace {
4780 
4781 // Returns whether the `instr` is either a constant, a scalar, or a
4782 // broadcasted constant/scalar.
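// Examples (illustrative): constant(1.0), a scalar parameter, or
// broadcast(constant(0)).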
4783 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
4784   return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
4785          (HloOpcode::kBroadcast == instr.opcode() &&
4786           (instr.operand(0)->IsConstant() ||
4787            ShapeUtil::IsScalar(instr.operand(0)->shape())));
4788 }
4789 
4790 // Recursive helper for GetFusionRoots below.
4791 static void GetFusionRootsRec(HloInstruction* root,
4792                               std::vector<HloInstruction*>& out) {
4793   if (root->opcode() == HloOpcode::kGetTupleElement) {
4794     return GetFusionRootsRec(root->mutable_operand(0), out);
4795   } else if (root->opcode() == HloOpcode::kTuple) {
4796     for (int i = 0; i < root->operand_count(); i++) {
4797       GetFusionRootsRec(root->mutable_operand(i), out);
4798     }
4799   } else {
4800     if (!out.empty() && out.back() == root) {
4801       return;
4802     }
4803     CHECK(!absl::c_linear_search(out, root))
4804         << "Fusion root contains instruction " << root->ToString()
4805         << " multiple times";
4806     out.push_back(root);
4807   }
4808 }
4809 
4810 // Returns instructions which are roots of the fusion, following the operands of
4811 // GTE instructions in the root tuple. Groups multiple subsequent instructions
4812 // with the same root into a single entry. CHECKs that the fusion never outputs
4813 // the same instruction twice, and that there are no explicitly created tuples
4814 // or nested GTEs in the fusion output.
4815 //
4816 // For input: (tuple (gte R1) (gte R1) O2)
4817 // Expected output: [R1, O2]
4818 //
4819 // For input: (tuple R1 R2 O2)
4820 // Expected output: [R1, R2, O2]
4821 //
4822 // For input: (tuple (gte R1) (gte R1) R2 O3)
4823 // Expected output: [R1, R2, O3]
4824 //
4825 // For input: R1
4826 // Expected output: [R1]
4827 static std::vector<HloInstruction*> GetFusionRoots(
4828     HloComputation* computation) {
4829   std::vector<HloInstruction*> out;
4830   GetFusionRootsRec(computation->root_instruction(), out);
4831   return out;
4832 }
4833 
4834 // Divides `num_reduces` reduces into groups. Different groups will be executed
4835 // in parallel. Generally speaking, we'd like to run the reduce instructions
4836 // in parallel without incurring too much recomputation overhead. The current
4837 // heuristic is to place reduce instructions that share nothing or only
4838 // (broadcasted) scalars/constants into different groups; otherwise, they are
4839 // placed in the same group. Non-reduce instructions always go with the reduce
4840 // instructions into the same group so long as they share any predecessors.
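// Illustration of the heuristic (descriptive): two reduces whose only shared
// operand is a broadcasted scalar are placed in separate groups (recomputing
// the scalar is cheap); two reduces that share a non-trivial elementwise
// producer end up in the same group.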
4841 std::vector<std::vector<HloInstruction*>> GroupDisjointReductions(
4842     HloComputation* fused_computation) {
4843   const Shape& root_shape = fused_computation->root_instruction()->shape();
4844   int num_fusion_outputs =
4845       fused_computation->root_instruction()->opcode() == HloOpcode::kTuple
4846           ? root_shape.tuple_shapes_size()
4847           : 1;
4848   CHECK_NE(0, num_fusion_outputs);
4849   if (num_fusion_outputs == 1) {
4850     return {{fused_computation->root_instruction()}};
4851   }
4852 
4853   std::vector<HloInstruction*> roots = GetFusionRoots(fused_computation);
4854   HloInstructionMap<tensorflow::UnionFind<HloInstruction*>> disjoint_sets;
4855 
4856   for (HloInstruction* root : roots) {
4857     disjoint_sets[root].Get() = root;
4858   }
4859 
4860   std::unique_ptr<HloReachabilityMap> reachability_map =
4861       HloReachabilityMap::Build(fused_computation);
4862   for (HloInstruction* instr : fused_computation->instructions()) {
4863     std::vector<HloInstruction*> reached_output_ids;
4864     bool added_to_reduce = false;
4865     for (HloInstruction* output : roots) {
4866       if (HloOpcode::kReduce == output->opcode() &&
4867           (IsBroadcastedConstantOrScalar(*instr))) {
4868         if (added_to_reduce) {
4869           // Do not group more than one output reduce instruction through
4870           // broadcasted constants or scalars, as the recomputation should be
4871           // acceptable.
4872           VLOG(3) << "Skip broadcasted constant or scalar "
4873                   << instr->ToString();
4874           continue;
4875         }
4876       }
4877       // Now group output instructions if they have common predecessors.
4878       if (reachability_map->IsReachable(instr, output)) {
4879         VLOG(3) << "Reaching " << output->ToString() << " from "
4880                 << instr->ToString();
4881         reached_output_ids.push_back(output);
4882         if (HloOpcode::kReduce == output->opcode()) {
4883           added_to_reduce = true;
4884         }
4885       }
4886     }
4887     for (size_t j = 1; j < reached_output_ids.size(); ++j) {
4888       disjoint_sets[reached_output_ids[0]].Merge(
4889           &disjoint_sets[reached_output_ids[j]]);
4890     }
4891   }
4892 
4893   // Place output instructions in the same set into the same group.
4894   HloInstructionMap<std::vector<HloInstruction*>> groups;
4895   for (HloInstruction* root : roots) {
4896     groups[disjoint_sets[root].Get()].push_back(root);
4897   }
4898 
4899   std::vector<std::vector<HloInstruction*>> ret;
4900   absl::c_for_each(
4901       groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
4902   return ret;
4903 }
4904 
4905 }  // namespace
4906 
4907 Status IrEmitterUnnested::EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion) {
4908   llvm::SmallVector<mlir::Operation*> fusion_roots = fusion.getFusionRoots();
4909 
4910   TF_ASSIGN_OR_RETURN(HloComputation * fused_computation,
4911                       GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
4912                                                           /*is_fusion=*/true));
4913 
4914   // Group disjoint reductions into groups that can be executed in parallel.
4915   std::vector<std::vector<HloInstruction*>> instr_index_groups =
4916       GroupDisjointReductions(fused_computation);
4917 
4918   VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
4919                     MlirToString(fusion));
4920 
4921   mlir::mhlo::ReduceOp first_reduce = mlir::cast<mlir::mhlo::ReduceOp>(
4922       *absl::c_find_if(fusion_roots, [](mlir::Operation* op) {
4923         return IsReductionFromOrToContiguousDimensions(op);
4924       }));
4925   TF_ASSIGN_OR_RETURN(ReductionCodegenInfo reduction_info,
4926                       ComputeReductionCodegenInfo(fusion, first_reduce));
4927   const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme();
4928 
4929   // block_y_count is set to instr_index_groups.size(), so that each reduction
4930   // group can be run in parallel by a different BlockIdy.
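  // E.g. with three disjoint groups the launch grid is
  // (num_blocks_physical, 3, 1); the block_id_y comparison emitted further
  // below routes each thread block to exactly one group.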
4931   LaunchDimensions launch_dimensions(
4932       {/*x=*/tiling_scheme.GetNumberOfBlocksPhysical(),
4933        /*y=*/static_cast<int64_t>(instr_index_groups.size()),
4934        /*z=*/1},
4935       {/*x=*/tiling_scheme.GetNumThreadsPerBlockPhysical(), /*y=*/1, /*z=*/1});
4936   VLOG(3) << "Launch dimensions of " << mlir::GetNameFromLoc(fusion.getLoc())
4937           << launch_dimensions.ToString();
4938 
4939   std::vector<llvm_ir::IrArray> ir_arrays;
4940   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> kernel_thunk,
4941                       BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
4942                                        launch_dimensions));
4943 
4944   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
4945                                           GetNestedComputer());
4946   FusedIrEmitter fused_emitter(elemental_emitter);
4947   CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
4948   for (int i = 0; i < fused_computation->num_parameters(); i++) {
4949     llvm_ir::IrArray ir_array = ir_arrays[i];
4950     HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
4951     fused_emitter.BindGenerator(
4952         *fused_operand,
4953         [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) {
4954           return ir_array.EmitReadArrayElement(index, &b_,
4955                                                fused_operand->name());
4956         });
4957   }
4958 
4959   // Get outputs.
4960   ReductionOutputMap result_ir_arrays;
4961 
4962   // Skip all parameter buffers first.
4963   int ir_arrays_idx = fused_computation->num_parameters();
4964   std::vector<HloInstruction*> roots = GetFusionRoots(fused_computation);
4965   for (HloInstruction* root : roots) {
4966     int get_num_results = GetNumOutputs(root->shape());
4967     result_ir_arrays[root] =
4968         absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, get_num_results);
4969     ir_arrays_idx += get_num_results;
4970   }
4971 
4972   // We always use the first reduce as representative to construct
4973   // ReductionCodegenInfo, since all the reductions are required to have the
4974   // same shape and layout as verified by `IsFusedReductionOutputConsistent()`.
4975   TF_ASSIGN_OR_RETURN(ReductionCodegenInfo reduction_codegen_info,
4976                       ComputeReductionCodegenInfo(fusion, first_reduce));
4977 
4978   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4979 
4980   // Use raw block_id_y to select the i-th parallel reduction to run. Using
4981   // block_id_y instead of block_id_x simplifies the index calculation
4982   // for reduction code generation as the block_id_y is orthogonal to
4983   // the indices used within the reductions.
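  // Rough shape of the emitted control flow (NVPTX shown as an illustration):
  //   %y = ctaid.y                     ; with range metadata [0, num_groups)
  //   if (%y == 0) { ...group 0 IR... }
  //   if (%y == 1) { ...group 1 IR... }
  //   ...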
4984   llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
4985       gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
4986   llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
4987                             llvm::cast<llvm::Instruction>(raw_block_id_y));
4988   for (int i = 0; i < instr_index_groups.size(); ++i) {
4989     TF_RETURN_IF_ERROR(ksl.IfWithStatus(
4990         StrCat("reduce-group-", i),
4991         b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] {
4992           return EmitIRForReduction(
4993               fusion, instr_index_groups[i], fused_emitter, result_ir_arrays,
4994               reduction_codegen_info, GetShape(first_reduce->getOperand(0)));
4995         }));
4996   }
4997 
4998   ThunkSequence thunks;
4999   if (!reduction_codegen_info.IsRaceFree()) {
5000     for (int i = 0; i < fusion_roots.size(); ++i) {
5001       mlir::Operation* output_instruction = fusion_roots[i];
5002       if (IsReductionFromOrToContiguousDimensions(output_instruction)) {
5003         TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
5004                             BuildFusedInitializerThunk(fusion, i));
5005         thunks.push_back(std::move(initializer_thunk));
5006       }
5007     }
5008   }
5009 
5010   thunks.push_back(std::move(kernel_thunk));
5011   auto sequential_thunk = std::make_unique<SequentialThunk>(
5012       GetThunkInfo(fusion), std::move(thunks));
5013   AddThunkToThunkSequence(std::move(sequential_thunk));
5014 
5015   return OkStatus();
5016 }
5017 
5018 // Emits code for slices based on the below structure. An if statement with
5019 // a guarding condition is generated for each ROOT slice.
5020 //
5021 // Pseudo code:
5022 //
5023 // Compute values of slice input operands
5024 //
5025 // Compute guarding_cond0
5026 // if (guarding_cond0) {
5027 //   Write to output of slice0
5028 // }
5029 //
5030 // Compute guarding_cond1
5031 // if (guarding_cond1) {
5032 //   Write to output of slice1
5033 // }
5034 //
5035 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5036     const HloComputation* fused_computation,
5037     absl::Span<const llvm_ir::IrArray> ir_arrays,
5038     const llvm_ir::IrArray::Index& index) {
5039   VLOG(10) << "Emitting slice input fusion for "
5040            << fused_computation->ToString();
5041 
5042   HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5043   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5044     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5045       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5046     }
5047     CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5048     return slice_or_tuple->operands();
5049   }();
5050 
5051   // Emit input operand values of slices.
5052   std::vector<llvm::Value*> input_ir_values;
5053   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5054                                      GetNestedComputer());
5055   FusedIrEmitter fused_emitter(elem_emitter);
5056   for (int i = 0; i < fused_computation->num_parameters(); i++) {
5057     fused_emitter.BindGenerator(
5058         *fused_computation->parameter_instruction(i),
5059         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5060           return ir_arrays[i].EmitReadArrayElement(index, &b_);
5061         });
5062   }
5063   for (const HloInstruction* slice : slice_instructions) {
5064     auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
5065     input_ir_values.push_back(input_generator(index).ValueOrDie());
5066   }
5067 
5068   // Emit for slice_instructions.
5069   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5070   for (int64_t i = 0; i < slice_instructions.size(); ++i) {
5071     HloInstruction* slice = slice_instructions[i];
5072 
5073     // guarding_cond := index >= start && index < limit, for each dim.
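    // E.g. (illustrative) slice_starts = {2, 0}, slice_limits = {10, 4} gives
    // guarding_cond = (i0 >= 2 && i0 < 10) && (i1 >= 0 && i1 < 4).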
5074     std::vector<llvm::Value*> index_within_ranges;
5075     for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5076       CHECK_EQ(slice->slice_strides(dim), 1);
5077       auto larger_or_equal_than_start = b_.CreateICmpSGE(
5078           index.multidim()[dim],
5079           index.GetConstantWithIndexType(slice->slice_starts(dim)));
5080       llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5081           index.multidim()[dim],
5082           index.GetConstantWithIndexType(slice->slice_limits(dim)));
5083       llvm::Value* within_range =
5084           b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5085       index_within_ranges.push_back(within_range);
5086     }
5087     llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5088 
5089     auto emit_slice_elem_func = [&] {
5090       const std::vector<llvm::Value*>& src_multidim = index.multidim();
5091       std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5092       for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5093         dst_multidim[dim] =
5094             Sub(src_multidim[dim],
5095                 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5096       }
5097       llvm_ir::IrArray src_ir_array =
5098           ir_arrays[fused_computation->num_parameters() + i];
5099       IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5100                                      index.GetType());
5101       src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5102                                          &b_);
5103     };
5104 
5105     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5106   }
5107   return OkStatus();
5108 }
5109 
5110 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5111     mlir::Operation* op) {
5112   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
5113 
5114   constexpr int unroll_factor = 1;
5115 
5116   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5117                       GetOrCreateSubComputationFromRegion(&fusion.getRegion(),
5118                                                           /*is_fusion=*/true));
5119 
5120   TF_ASSIGN_OR_RETURN(Shape element_shape,
5121                       GetConsistentInputShapeForRootSlices(fused_computation));
5122   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
5123                       CalculateLaunchDimensions(
5124                           element_shape, ir_emitter_context_->gpu_device_info(),
5125                           {unroll_factor}));
5126 
5127   std::vector<llvm_ir::IrArray> ir_arrays;
5128   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
5129                       BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
5130                                        launch_dimensions));
5131 
5132   Status emit_status =
5133       ParallelLoopEmitter(
5134           [&](const llvm_ir::IrArray::Index index) -> Status {
5135             return EmitElementForInputFusibleSlices(fused_computation,
5136                                                     ir_arrays, index);
5137           },
5138           element_shape, launch_dimensions, &b_)
5139           .EmitLoop(IrName(GetIrNameFromLoc(fusion.getLoc())),
5140                     GetIndexTypeForKernel(
5141                         fusion, launch_dimensions.launch_bound(), &b_));
5142 
5143   thunk_sequence_.emplace_back(std::move(kernel_thunk));
5144 
5145   return emit_status;
5146 }
5147 
5148 Status IrEmitterUnnested::EmitDynamicUpdateSlice(
5149     mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation) {
5150   // Fusion node with dynamic-update-slice as the root where the op's input
5151   // (i.e. array to update) shares the same slice as its output.  In this case
5152   // we have a special algorithm that modifies the output in place without
5153   // touching the un-updated elements.
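  // Sketch: the kernel is launched over update_shape only; each thread
  // computes one element of the update and writes it to
  // output[start_indices + index], leaving every other output element as-is.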
5154   CHECK_EQ(1, GetHloOutputs(fusion_op).size());
5155 
5156   // Shape of the dynamic-update-slice's "update" operand.
5157   Shape update_shape =
5158       fused_computation->root_instruction()->operand(1)->shape();
5159 
5160   TF_ASSIGN_OR_RETURN(
5161       LaunchDimensions launch_dimensions,
5162       CalculateLaunchDimensions(update_shape,
5163                                 ir_emitter_context_->gpu_device_info()));
5164 
5165   // Set up kernel thunk and fused ir emitter.
5166   std::vector<llvm_ir::IrArray> ir_arrays;
5167   TF_ASSIGN_OR_RETURN(auto fusion_thunk,
5168                       BuildKernelThunk(fusion_op, GetThunkInfo(fusion_op),
5169                                        &ir_arrays, launch_dimensions));
5170   AddThunkToThunkSequence(std::move(fusion_thunk));
5171 
5172   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
5173                                           GetNestedComputer());
5174 
5175   FusedIrEmitter fused_emitter(elemental_emitter);
5176 
5177   for (int i = 0; i < fused_computation->num_parameters(); i++) {
5178     auto fused_operand = fused_computation->parameter_instruction(i);
5179     fused_emitter.BindGenerator(
5180         *fused_operand, [this, &ir_arrays, i,
5181                          fused_operand](const llvm_ir::IrArray::Index& index) {
5182           return ir_arrays[i].EmitReadArrayElement(index, &b_,
5183                                                    fused_operand->name());
5184         });
5185   }
5186 
5187   // Array to write into.  Because this is an in-place operation, this is the
5188   // same as operand 0's array.
5189   const IrArray& output_array = ir_arrays.back();
5190 
5191   return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
5192       fused_computation, output_array, &fused_emitter, launch_dimensions, &b_);
5193 }
5194 
5195 Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op,
5196                                       const HloComputation* fused_computation) {
5197   auto* root = fused_computation->root_instruction();
5198 
5199   ThunkSequence thunks;
5200   // The initialization from 'operand' uses different loop bounds, so
5201   // emit it in a separate kernel. Treat it like a loop fusion, writing to
5202   // the output buffer.
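  // In other words, two kernels are emitted: the first loops over
  // root->shape() and materializes operand(0) into the output buffer; the
  // second loops over the updates shape and applies the scatter on top of it.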
5203   {
5204     auto unroll_factor = ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
5205     const Shape& element_shape = root->shape();
5206     TF_ASSIGN_OR_RETURN(
5207         LaunchDimensions launch_dimensions,
5208         CalculateLaunchDimensions(element_shape,
5209                                   ir_emitter_context_->gpu_device_info(),
5210                                   {unroll_factor, /*few_waves=*/false}));
5211 
5212     std::vector<llvm_ir::IrArray> ir_arrays;
5213     TF_ASSIGN_OR_RETURN(auto operand_thunk,
5214                         BuildKernelThunk(fusion_op, Thunk::ThunkInfo(),
5215                                          &ir_arrays, launch_dimensions));
5216     thunks.push_back(std::move(operand_thunk));
5217 
5218     GpuElementalIrEmitter operand_elemental_emitter(hlo_module_config_, module_,
5219                                                     &b_, GetNestedComputer());
5220     FusedIrEmitter operand_fused_emitter(operand_elemental_emitter);
5221     for (int i = 0; i < fused_computation->num_parameters(); i++) {
5222       auto fused_operand = fused_computation->parameter_instruction(i);
5223       operand_fused_emitter.BindGenerator(
5224           *fused_operand,
5225           [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
5226             return ir_arrays[i].EmitReadArrayElement(index, &b_,
5227                                                      fused_operand->name());
5228           });
5229     }
5230     TF_ASSIGN_OR_RETURN(auto generator,
5231                         operand_fused_emitter.GetGenerator(*root->operand(0)));
5232 
5233     TF_RETURN_IF_ERROR(
5234         ParallelLoopEmitter(generator, {ir_arrays.back()}, launch_dimensions,
5235                             &b_, {unroll_factor})
5236             .EmitLoop(IrName(GetIrNameFromLoc(fusion_op.getLoc())),
5237                       GetIndexTypeForKernel(
5238                           fusion_op, launch_dimensions.launch_bound(), &b_)));
5239   }
5240 
5241   // Now build the actual scatter, reading and writing to the freshly
5242   // filled output buffer.
5243   {
5244     const Shape& updates_shape = root->operand(2)->shape();
5245     TF_ASSIGN_OR_RETURN(
5246         LaunchDimensions launch_dimensions,
5247         CalculateLaunchDimensions(updates_shape,
5248                                   ir_emitter_context_->gpu_device_info()));
5249     std::vector<llvm_ir::IrArray> ir_arrays;
5250     TF_ASSIGN_OR_RETURN(auto scatter_thunk,
5251                         BuildKernelThunk(fusion_op, Thunk::ThunkInfo(),
5252                                          &ir_arrays, launch_dimensions));
5253     thunks.push_back(std::move(scatter_thunk));
5254     // Spin up a new fused emitter for the scatter kernel and emit it.
5255     GpuElementalIrEmitter scatter_elemental_emitter(hlo_module_config_, module_,
5256                                                     &b_, GetNestedComputer());
5257     FusedIrEmitter scatter_fused_emitter(scatter_elemental_emitter);
5258     for (int i = 0; i < fused_computation->num_parameters(); i++) {
5259       auto fused_operand = fused_computation->parameter_instruction(i);
5260       scatter_fused_emitter.BindGenerator(
5261           *fused_operand,
5262           [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
5263             return ir_arrays[i].EmitReadArrayElement(index, &b_,
5264                                                      fused_operand->name());
5265           });
5266     }
5267 
5268     TF_ASSIGN_OR_RETURN(const auto dim_numbers,
5269                         mlir::LhloDialectEmitter::GetScatterDimensionNumbers(
5270                             root, fusion_op.getContext()));
5271 
5272     ScatterDescriptor desc;
5273     desc.name = IrName(root);
5274     desc.operand_shape = root->operand(0)->shape();
5275     desc.scatter_indices_shape = root->operand(1)->shape();
5276     desc.updates_shape = updates_shape;
5277     desc.dim_numbers = dim_numbers;
5278     desc.unique_indices = root->unique_indices();
5279     desc.update_computation = root->called_computations()[0];
5280     desc.output = ir_arrays.back();
5281     TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen,
5282                         scatter_fused_emitter.GetGenerator(*root->operand(1)));
5283     TF_ASSIGN_OR_RETURN(desc.updates_gen,
5284                         scatter_fused_emitter.GetGenerator(*root->operand(2)));
5285     desc.get_index_type = [&](int64_t launch_size) {
5286       return GetIndexTypeForKernel(root, launch_size, &b_);
5287     };
5288 
5289     TF_RETURN_IF_ERROR(
5290         EmitScatter(desc, thunks.back().get(), launch_dimensions));
5291   }
5292   AddThunkToThunkSequence(std::make_unique<SequentialThunk>(
5293       GetThunkInfo(fusion_op), std::move(thunks)));
5294   return OkStatus();
5295 }
5296 
5297 Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
5298   if (mlir::isa<mlir::func::ConstantOp, mlir::arith::ConstantOp,
5299                 mlir::memref::ViewOp, mlir::memref::ReinterpretCastOp,
5300                 mlir::func::ReturnOp, mlir::lmhlo::TerminatorOp>(op)) {
5301     return OkStatus();
5302   }
5303 
5304   if (mlir::isa<mlir::memref::GetGlobalOp>(op)) {
5305     return EmitConstant(op);
5306   }
5307 
5308   if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(op)) {
5309     if (call.getCallTargetName() == "PadToStatic") {
5310       return EmitPadToStatic(op);
5311     }
5312     if (call.getCallTargetName() == "SliceToDynamic") {
5313       return EmitSliceToDynamic(op);
5314     }
5315 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5316     if (call.getCallTargetName() == kTriangularSolveCallTarget) {
5317       return EmitTriangularSolveCustomCall(op);
5318     }
5319 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5320 
5321     return EmitCustomCallThunk(op);
5322   }
5323 
5324   if (mlir::isa<mlir::lmhlo_gpu::GEMMOp>(op)) {
5325     return EmitGemmThunk(op);
5326   }
5327 
5328 #if GOOGLE_CUDA
5329   if (mlir::isa<mlir::lmhlo_gpu::CublasLtMatmulOp>(op)) {
5330     return EmitCublasLtMatmulThunk(op);
5331   }
5332 #endif  // GOOGLE_CUDA
5333 
5334   if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
5335                 mlir::lmhlo_gpu::ConvForwardFusedOp,
5336                 mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
5337                 mlir::lmhlo_gpu::ConvBackwardFilterOp,
5338                 mlir::lmhlo_gpu::ConvBackwardInputOp>(op)) {
5339     return EmitConvolutionThunk(op);
5340   }
5341 
5342 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5343   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) {
5344     return EmitCholeskyThunk(op);
5345   }
5346 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5347 
5348   if (mlir::isa<mlir::lmhlo::FftOp>(op)) {
5349     return EmitFftThunk(op);
5350   }
5351 
5352   if (mlir::isa<mlir::lmhlo::TriangularSolveOp>(op)) {
5353     return InternalError(
5354         "TriangularSolve is implemented as a custom-call; we do not expect to "
5355         "lower a true HLO TriangularSolve op.");
5356   }
5357 
5358   if (mlir::isa<mlir::lmhlo::FusionOp>(op)) {
5359     return EmitFusion(op);
5360   }
5361 
5362   if (mlir::isa<mlir::lmhlo::SelectAndScatterOp>(op)) {
5363     return EmitSelectAndScatter(op);
5364   }
5365 
5366   if (mlir::isa<mlir::lmhlo::RngGetAndUpdateStateOp>(op)) {
5367     return EmitRngGetAndUpdateState(op);
5368   }
5369 
5370   if (mlir::isa<mlir::lmhlo::ScatterOp>(op)) {
5371     return EmitScatter(op);
5372   }
5373 
5374   if (mlir::isa<mlir::lmhlo::SortOp>(op)) {
5375     return EmitSort(op);
5376   }
5377 
5378   if (mlir::isa<mlir::lmhlo::ReplicaIdOp>(op)) {
5379     return EmitReplicaOrPartitionId<ReplicaIdThunk, mlir::lmhlo::ReplicaIdOp>(
5380         op);
5381   }
5382 
5383   if (mlir::isa<mlir::lmhlo::PartitionIdOp>(op)) {
5384     return EmitReplicaOrPartitionId<PartitionIdThunk,
5385                                     mlir::lmhlo::PartitionIdOp>(op);
5386   }
5387 
5388   if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(op)) {
5389     return EmitCollectivePermute(op);
5390   }
5391 
5392   if (mlir::isa<mlir::lmhlo::AllGatherOp>(op)) {
5393     return EmitNcclThunk<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(op);
5394   }
5395 
5396   if (mlir::isa<mlir::lmhlo::AllReduceOp>(op)) {
5397     return EmitNcclThunk<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(op);
5398   }
5399 
5400   if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
5401     return EmitNcclThunk<NcclAllReduceStartThunk,
5402                          mlir::lmhlo_gpu::AllReduceStartOp>(op);
5403   }
5404 
5405   if (mlir::isa<mlir::lmhlo_gpu::AllReduceDoneOp>(op)) {
5406     return EmitAllReduceDone(op);
5407   }
5408 
5409   if (mlir::isa<mlir::lmhlo::ReduceScatterOp>(op)) {
5410     return EmitNcclThunk<NcclReduceScatterThunk, mlir::lmhlo::ReduceScatterOp>(
5411         op);
5412   }
5413 
5414   if (mlir::isa<mlir::lmhlo::AllToAllOp>(op)) {
5415     return EmitNcclThunk<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(op);
5416   }
5417 
5418   if (mlir::isa<mlir::lmhlo::InfeedOp>(op)) {
5419     return EmitInfeed(op);
5420   }
5421 
5422   if (mlir::isa<mlir::lmhlo::OutfeedOp>(op)) {
5423     return EmitOutfeed(op);
5424   }
5425 
5426   if (mlir::isa<mlir::lmhlo::CaseOp>(op)) {
5427     return EmitConditional(op);
5428   }
5429 
5430   if (mlir::isa<mlir::lmhlo::WhileOp>(op)) {
5431     return EmitWhile(op);
5432   }
5433 
5434   if (mlir::isa<mlir::gpu::LaunchFuncOp>(op)) {
5435     return EmitLaunchFunc(op);
5436   }
5437 
5438   // Remaining arith.constant ops are the gpu.launch_func dimensions as a result
5439   // of inlining the fusion region after lowering. They can safely be skipped
5440   // because constants have no side effects.
5441   if (mlir::isa<mlir::arith::ConstantOp>(op)) {
5442     return Status::OK();
5443   }
5444 
5445   return InternalError("Unrecognized op: %s", MlirToString(op));
5446 }
5447 
5448 Status IrEmitterUnnested::EmitLmhloRegion(mlir::Region* region) {
5449   for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) {
5450     TF_RETURN_IF_ERROR(EmitOp(&op));
5451   }
5452   return OkStatus();
5453 }
5454 
5455 void IrEmitterUnnested::GetDependentDialects(mlir::DialectRegistry& registry) {
5456   registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
5457                   mlir::gpu::GPUDialect, mlir::lmhlo::LmhloDialect,
5458                   mlir::lmhlo_gpu::LmhloGpuDialect, mlir::mhlo::MhloDialect>();
5459   mlir::registerLLVMDialectTranslation(registry);
5460   mlir::registerNVVMDialectTranslation(registry);
5461   mlir::registerROCDLDialectTranslation(registry);
5462 }
5463 
5464 Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(mlir::Operation* op) {
5465   auto module = op->getParentOfType<mlir::ModuleOp>();
5466   // Include the HloModule's unique_id in the thunk's module name so that xprof
5467   // shows different modules differently, addressing b/202415436#comment24.
5468   // xprof calls this the "program_id".
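  // The resulting annotation looks like
  // "Thunk:#hlo_op=fusion.3,hlo_module=cluster_0,program_id=7#"
  // (values illustrative).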
5469   std::string unique_id_str;
5470   if (auto unique_id_attr =
5471           module->getAttrOfType<mlir::IntegerAttr>("mhlo.unique_id")) {
5472     unique_id_str = absl::StrFormat(",program_id=%d",
5473                                     unique_id_attr.getValue().getZExtValue());
5474   }
5475   Thunk::ThunkInfo thunk_info;
5476   thunk_info.profile_annotation = absl::StrFormat(
5477       "Thunk:#hlo_op=%s,hlo_module=%s%s#", mlir::GetNameFromLoc(op->getLoc()),
5478       mlir::GetNameFromLoc(module->getLoc()), unique_id_str);
5479   return thunk_info;
5480 }
5481 
5482 }  // namespace gpu
5483 }  // namespace xla
5484