1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
18 
19 #include <array>
20 #include <functional>
21 #include <string>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
25 #include "tensorflow/compiler/xla/service/custom_call_status.h"
26 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
27 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
28 #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
29 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
30 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
31 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
34 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
35 
36 namespace xla {
37 namespace gpu {
38 
39 struct BufferSlice {
40   // The root buffer to look at.
41   BufferAllocation::Slice buffer_slice;
42 
43   // The global constant name of the buffer, if it's a constant.
44   std::string constant_name;
45 
46   // The buffer is modified by the kernel.
47   bool written = false;
48 
49   Shape shape;
50 };
51 
52 // Emits LLVM IR for an "unnested computation".
53 //
54 // An unnested computation is an HloComputation which you run by executing one
55 // or more kernels for each HloInstruction it contains.  Examples of unnested
56 // computations:
57 //
58 //  - An HloModule's root computation,
59 //  - The body of an HLO while loop,
60 //  - The true/false computation of an HLO conditional.
61 //
62 // Note the opportunity for confusion -- the while loop's computation is nested
63 // within the root computation, but it's emitted using IrEmitterUnnested!  Don't
64 // think about it too hard.
65 //
66 // Examples of things that are not unnested computations:
67 //
68 //  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
69 //    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
70 //    really an IrEmitter, but is more an "IR generator generator".)
71 //
72 class IrEmitterUnnested : public IrEmitter {
73  public:
74   // Contains threading information. Note that for performance we might apply
75   // thread id "scaling" where the physical thread id (to achieve good SM
76   // occupancy) will differ from logical thread id. This struct contains
77   // logical thread ids, along with meta-information about the scaling applied.
78   struct ThreadIdInfo {
79     ThreadIdInfo(llvm::Value* thread_id, llvm::Value* thread_id_x,
80                  llvm::Value* thread_id_y, llvm::Value* lane_id,
81                  llvm::Value* block_id, llvm::Value* scaling)
82         : thread_id(thread_id),
83           thread_id_x(thread_id_x),
84           thread_id_y(thread_id_y),
85           lane_id(lane_id),
86           block_id(block_id),
87           scaling(scaling) {}
88 
89     llvm::Value* thread_id;
90 
91     // X-coordinate calculated from thread id: `thread_id % num_threads_x`
92     llvm::Value* thread_id_x;
93 
94     // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
95     llvm::Value* thread_id_y;
96 
97     // Lane id: `thread_id % WarpSize`
98     llvm::Value* lane_id;
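
    // Worked example derived from the formulas above (ignoring scaling): with
    // num_threads_x = 32 and a warp size of 32, a thread with thread_id = 70
    // gets thread_id_x = 70 % 32 = 6, thread_id_y = 70 / 32 = 2, and
    // lane_id = 70 % 32 = 6.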
99 
100     // Block id.
101     llvm::Value* block_id;
102 
103     // Emits a GEP into shared memory, taking virtual thread scaling into
104     // account. Automatically inserts the first zero required by LLVM GEP.
105     // Defined on ThreadIdInfo to keep `scaling` private.
106     //
107     // Same semantics as CreateInBoundsGEP.
108     llvm::Value* GEPIntoSharedMemory(
109         llvm::IRBuilder<>* b, llvm::GlobalVariable* shared,
110         absl::Span<llvm::Value* const> idx_major_to_minor,
111         const llvm::Twine& name = "") const;
112 
113     // Calculates the pointee type of the llvm::Value returned by
114     // GEPIntoSharedMemory.
115     llvm::Type* GEPIntoSharedMemoryType(
116         llvm::GlobalVariable* shared,
117         absl::Span<llvm::Value* const> idx_major_to_minor) const;
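
    // Hypothetical usage sketch (`b` is the current llvm::IRBuilder<> and
    // `tile` a shared-memory buffer from AllocateShared; the names are
    // illustrative, not part of this API):
    //
    //   llvm::Value* addr = thread_id_info.GEPIntoSharedMemory(
    //       b, tile, {thread_id_info.thread_id_y, thread_id_info.thread_id_x});
    //   llvm::Type* elem_ty = thread_id_info.GEPIntoSharedMemoryType(
    //       tile, {thread_id_info.thread_id_y, thread_id_info.thread_id_x});
    //   llvm::Value* value = b->CreateLoad(elem_ty, addr);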
118 
119    private:
120     llvm::Value* scaling;
121   };
122 
123   absl::string_view platform_name() const {
124     return ir_emitter_context_->platform_name();
125   }
126 
127   using ValueVector3 = std::array<llvm::Value*, 3>;
128 
129   // A function object to generate code to process one element in a tile.
130   //
131   // index: the index for the first output element of the current thread.
132   // y_loc: The y coordinate within a tile.
133   // x_loc: The x coordinate within a tile.
134   // x_iter_num: When a thread processes N elements in the X dimension, x_iter_num
135   //             has a value of 0..N-1 to identify the element being processed.
136   using EmitElementFunction = std::function<void(
137       const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
138       llvm::Value* y_loc, llvm::Value* x_loc, llvm::Value* x_iter_num)>;
139 
140   using ConstantGenerator = std::function<llvm::Value*(int64_t)>;
141 
142   // A function to generate the code to emit the entire tile.
143   //
144   // index: Absolute coordinate of the start of the tile in input.
145   // tile_dimensions: Size of the tile
146   using TileElementGenerator = std::function<void(
147       const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
148       ValueVector3 tile_dimensions)>;
149 
150   // Fusion root -> array of indexes, one per reduction output.
151   using ReductionOutputMap =
152       ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;
153 
154   using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;
155 
156   IrEmitterUnnested(const IrEmitterUnnested&) = delete;
157   IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
158 
159   static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
160       const HloModuleConfig& hlo_module_config,
161       IrEmitterContext* ir_emitter_context);
162 
163   // Transfers the ownership of thunk_sequence_ out.
164   std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
165     return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
166   }
167 
168   // Emits code for the given LMHLO region.
169   //
170   // Also populates related information in 'ir_emitter_context_' for
171   // large-constant initializations. Large constants don't get initializers in
172   // the generated code and so must be initialized by XLA. The value of these
173   // constants will be stored in 'content'. Constants with initializers in the
174   // generated code will have empty 'content'.
175   Status EmitLmhloRegion(mlir::Region* region);
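
  // Typical usage sketch (setup and error handling elided; illustrative only,
  // not a prescribed calling sequence):
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<IrEmitterUnnested> emitter,
  //                       IrEmitterUnnested::Create(hlo_module_config,
  //                                                 ir_emitter_context));
  //   TF_RETURN_IF_ERROR(emitter->EmitLmhloRegion(&lmhlo_region));
  //   std::unique_ptr<ThunkSequence> thunks = emitter->ConsumeThunkSequence();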
176 
177   static void GetDependentDialects(mlir::DialectRegistry& registry);
178 
179  private:
180   IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
181                     IrEmitterContext* ir_emitter_context);
182 
183   // IrEmitterUnnested handles the following instructions differently from
184   // IrEmitter. It also mixes in some special handling for custom kernels
185   // via the ThunkEmitter.
186   Status EmitConstant(mlir::Operation* op);
187 
188   Status EmitCopy(mlir::Operation* op);
189 
190   Status EmitConditional(mlir::Operation* op);
191   Status EmitConvolutionThunk(mlir::Operation* op);
192   Status EmitGemmThunk(mlir::Operation* op);
193 #if GOOGLE_CUDA
194   Status EmitCublasLtMatmulThunk(mlir::Operation* op);
195 #endif  // GOOGLE_CUDA
196 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
197   Status EmitCholeskyThunk(mlir::Operation* op);
198 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
199   Status EmitCustomCallThunk(mlir::Operation* op);
200   Status EmitFftThunk(mlir::Operation* op);
201   Status EmitFusion(mlir::Operation* op);
202   Status EmitLaunchFunc(mlir::Operation* op);
203   Status EmitLoopFusion(mlir::Operation* op);
204   Status EmitReduce(mlir::Operation* op);
205   Status EmitSelectAndScatter(mlir::Operation* op);
206   Status EmitWhile(mlir::Operation* op);
207   Status EmitInfeed(mlir::Operation* op);
208   Status EmitOutfeed(mlir::Operation* op);
209   Status EmitRngGetAndUpdateState(mlir::Operation* op);
210   Status EmitScatter(mlir::Operation* op);
211   Status EmitSort(mlir::Operation* op);
212 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
213   Status EmitTriangularSolveCustomCall(mlir::Operation* op);
214 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
215 
216   template <typename NcclThunkType, typename OpTy>
217   Status EmitNcclThunk(mlir::Operation* op);
218   Status EmitAllReduceDone(mlir::Operation* op);
219 
220   template <typename ThunkType, typename OpT>
221   Status EmitReplicaOrPartitionId(mlir::Operation* op);
222 
223   Status EmitCollectivePermute(mlir::Operation* op);
224 
225   Status EmitOp(mlir::Operation* op);
226 
227   static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op);
228 
229   Status EmitTargetElementLoop(
230       const HloInstruction& hlo,
231       const llvm_ir::ElementGenerator& body_emitter) override;
232 
233   // Adds an owning Thunk object to the thunk sequence.
234   void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
235     thunk_sequence_.emplace_back(std::move(thunk));
236   }
237 
238   // Loads data from a potentially unaligned address. If the address is offset by
239   // `alignment_bytes`, data is read in units of `alignment_bytes` to avoid
240   // misaligned memory reads in CUDA; otherwise, the entire value is loaded
241   // from the given memory address.
242   //
243   //   address: the memory address to load data from.
244   //   data_type: the type of data to load.
245   //   alignment_bytes: the number of bytes required to align. The number of
246   //     bytes of the data_type must be divisible by alignment_bytes.
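  //
  // Conceptual sketch (not the exact emitted IR): a 64-bit load with
  // alignment_bytes = 4 is assembled from two 32-bit loads; the little-endian
  // recombination shown here is illustrative.
  //
  //   uint64_t Load64With4ByteAlignment(const uint32_t* address) {
  //     uint64_t lo = address[0];
  //     uint64_t hi = address[1];
  //     return lo | (hi << 32);
  //   }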
247   llvm::Value* CreateLoad(llvm::Value* address, llvm::Type* data_type,
248                           int alignment_bytes);
249 
250   // Stores data at a potentially unaligned address. If the address is offset by
251   // `alignment_bytes`, data is stored in units of `alignment_bytes` to avoid
252   // misaligned memory writes in CUDA; otherwise, the entire value is stored at
253   // the given memory address.
254   //
255   //   data: the data to be stored.
256   //   address: the memory address to store data.
257   //   alignment_bytes: the number of bytes required to align. The number of
258   //     bytes of the data_type must be divisible by alignment_bytes.
259   void CreateStore(llvm::Value* data, llvm::Value* address,
260                    int alignment_bytes);
261 
262   // Input = {static array, dynamic_dim0, dynamic_dim1}
263   // Output = {dynamic array(with dynamic dimension meta data at the end)}
264   // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
265   // (`_` stands for padding)
266   // Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
267   // Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
268 
269   // pseudo code for padToStatic on a 2d array
270   //   ```
271   // void padToStatic(int** input, int** output, int threads_per_block,
272   //                  int meta_data_offset, int max_num_element,
273   //                  int static_dim0_size, int static_dim1_size) {
274   //   int* source_array = input[0];
275   //   int* dest_array = output[0];
276 
277   //   // extract the dynamic dimension from the source array's metadata
278   //   int* dyn_dim0_size = source_array + meta_data_offset;
279   //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
280 
281   //   // only one thread needs to store the dynamic index
282   //   int thread_id = GetThreadId();
283   //   int block_id = GetBlockId();
284   //   if (thread_id == 0 && block_id == 0) {
285   //     *output[1] = *dyn_dim0_size;
286   //     *output[2] = *dyn_dim1_size;
287   //   }
288 
289   //   int dyn_element_total = 1;
290   //   dyn_element_total *= *dyn_dim0_size;
291   //   dyn_element_total *= *dyn_dim1_size;
292   //   linear_index = block_id * threads_per_block + thread_id;
293   //   if (linear_index < max_num_element) {
294   //     Index static_index =
295   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
296   //     if (linear_index < dyn_element_total) {
297   //       Index dyn_index =
298   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
299   //       dest_array[dyn_index.dim0][dyn_index.dim1] =
300   //           source_array[static_index.dim0][static_index.dim1];
301   //     }
302   //   }
303   //   return;
304   // }
305   //   ```
306   Status EmitPadToStatic(mlir::Operation* op);
307 
308   // Input = {dynamic array(with dynamic dimension meta data at the end)}
309   // Output = {static array, dynamic_dim0, dynamic_dim1}
310   // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
311   // (`_` stands for padding)
312   // Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
313   // Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
314 
315   // pseudo code for sliceToDynamic on a 2d array
316   //   ```
317   // void sliceToDynamic(int** input, int** output, int threads_per_block,
318   //                  int meta_data_offset, int max_num_element,
319   //                  int static_dim0_size, int static_dim1_size) {
320   //   int* source_array = input[0];
321   //   int* dest_array = output[0];
322 
323   //   // calculate the location where metadata needs to be inserted
324   //   int* dyn_dim0_size = dest_array + meta_data_offset;
325   //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
326 
327   //   // only one thread needs to store the dynamic index
328   //   int thread_id = GetThreadId();
329   //   int block_id = GetBlockId();
330   //   if (thread_id == 0 && block_id == 0) {
331   //     *dyn_dim0_size = *output[1];
332   //     *dyn_dim1_size = *output[2];
333   //   }
334 
335   //   int dyn_element_total = 1;
336   //   dyn_element_total *= *dyn_dim0_size;
337   //   dyn_element_total *= *dyn_dim1_size;
338   //   linear_index = block_id * threads_per_block + thread_id;
339   //   if (linear_index < max_num_element) {
340   //     Index static_index =
341   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
342   //     if (linear_index < dyn_element_total) {
343   //       Index dyn_index =
344   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
345   //       dest_array[static_index.dim0][static_index.dim1] =
346   //           source_array[dyn_index.dim0][dyn_index.dim1];
347   //     }
348   //   }
349   //   return;
350   // }
351   //   ```
352   Status EmitSliceToDynamic(mlir::Operation* op);
353 
354   StatusOr<BufferAllocation::Slice> GetAllocationSlice(
355       mlir::Value v, std::string* constant_name = nullptr);
356 
357   int64_t ByteSizeOf(const Shape& shape) const {
358     return llvm_ir::ByteSizeOf(
359         shape, ir_emitter_context_->llvm_module()->getDataLayout());
360   }
361 
362   // Builds the prototype of an IR kernel with the given name and adds it to the
363   // module. The kernel takes as arguments pointers to the given buffer allocations.
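  //
  // For example, a kernel over three buffer allocations gets a prototype along
  // the lines of the following (a sketch; the kernel name, attributes, and
  // metadata depend on the target and the op being emitted):
  //
  //   define void @some_fusion(i8* %alloc0, i8* %alloc1, i8* %alloc2)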
364   llvm::Function* BuildKernelPrototype(
365       absl::string_view name, absl::Span<const BufferAllocation* const> args);
366 
367   // Helper for writing extra outputs from inside a reduce kernel.
368   Status EmitExtraOutputsForReduce(const ReductionOutputMap& result_ir_arrays,
369                                    const llvm_ir::IrArray::Index& index,
370                                    const ReductionCodegenInfo& reduction_info,
371                                    const ExtraOutputGensMap& extra_output_gens);
372 
373   // Generates code for reduction to contiguous dimensions.
374   //
375   // Row reduction uses the following algorithm described in CUDA-like
376   // pseudocode:
377   //
378   // ```
379   //  __global__ void reduce(int num_rows, float *in, float *out) {
380   //    __shared__ float[32] cache;
381   //    int offset = blockDim.x * blockIdx.x + threadIdx.x;
382   //    if (offset >= num_rows) return;
383   //    int tile_bound = std::min(offset + kTileSizeX, num_rows);
384   //    float accum = 0;
385   //    for (int i=offset; i<num_rows; i+= blockDim.x) {
386   //      accum += in[i];
387   //    }
388   //    accum = warp_reduce(accum);
389   //    if (threadIdx.x % WarpSize == 0) {
390   //      cache[threadIdx.x / WarpSize] = accum;
391   //    }
392   //    __syncthreads();
393   //    if (threadIdx.x / WarpSize == 0) {
394   //      bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
395   //      float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
396   //      block_accum = warp_reduce(block_accum);
397   //      if (threadIdx.x == 0) {
398   //        *out += block_accum;
399   //      }
400   //    }
401   //  }
402   // ```
403   //
404   // Column reduction uses the following algorithm:
405   //
406   // ```
407   // void reduce(float** in, float* out) {
408   //   __shared__ float[32][33] cache;
409   //   int thread_id = GetThreadId();
410   //   int block_id = GetBlockId();
411   //   int tile_size = 128;
412   //
413   //   float accum = 0;
414   //   for (int i=0; i<tile_size; i++) {
415   //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
416   //   }
417   //   cache[thread_id.x][thread_id.y] = accum;
418   //
419   //   __syncthreads();
420   //   accum = cache[thread_id.y][thread_id.x];
421   //   accum = warp_reduce(accum); // Sum all the values of `accum` in the same
422   //                               // warp.
423   //
424   //   if (thread_id.y % 32 == 0) {
425   //     out[block_id * 32 + thread_id.x] = accum;
426   //   }
427   // }
428   // ```
429   //
430   // Moreover, a heuristic is implemented to divide the reduce instructions
431   // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
432   // for details about the heuristic). Reduce instructions in the same group
433   // will run sequentially while different groups will run in parallel.
434   //
435   // We use the raw block_id_y to select the reduce groups for execution without
436   // complicating the index calculation in the code generation of the reduce
437   // instructions. In other words, a block_id_y is assigned to a group and so
438   // different groups can be run in parallel.
439   Status EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion);
440 
441   // Computes the KernelMappingScheme for the reduce HLO and indicates whether
442   // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
443   // and first_reduce are the same instruction. For a kInput fusion,
444   // unnested_hlo is the fusion instruction while first_reduce is the first
445   // reduce op.
446   StatusOr<ReductionCodegenInfo> ComputeReductionCodegenInfo(
447       mlir::lmhlo::FusionOp fusion, mlir::mhlo::ReduceOp first_reduce);
448 
449   // Generates code for input-fusible slices.
450   //
451   // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
452   // of all ROOT slices need to be the same while their output shapes can be
453   // different. On the other hand, the input ranges of the slices may
454   // overlap. Further generalization/specialization can be added when the
455   // need arises.
456   Status EmitInputFusibleNonStridedSlices(mlir::Operation* op);
457 
458   Status EmitElementForInputFusibleSlices(
459       const HloComputation* fused_computation,
460       absl::Span<const llvm_ir::IrArray> ir_arrays,
461       const llvm_ir::IrArray::Index& index);
462 
463   // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions in
464   // the process. Scatter indices are taken from `scatter_indices_gen`, updates
465   // from `updates_gen`. The output buffer is expected to have the operand
466   // values in it already. If unique_indices is false, we will use an atomic
467   // update. Using true for unique_indices behaves properly only when it is
468   // guaranteed that the indices to be updated do not overlap. The caller is
469   // responsible for ensuring this is the case.
470   Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
471                      const LaunchDimensions& launch_dimensions,
472                      const llvm_ir::IrArray& output,
473                      const llvm_ir::ElementGenerator& scatter_indices_gen,
474                      const llvm_ir::ElementGenerator& updates_gen,
475                      std::function<llvm::Type*(int64_t)> get_index_type);
476 
477   // Structure describing a scatter operation for IR emission.
478   // TODO(jurahul): Migrate element generators to use MLIR.
479   //                Migrate update_computation to be an MLIR Region.
480   struct ScatterDescriptor {
481     std::string name;
482     Shape operand_shape;
483     Shape scatter_indices_shape;
484     Shape updates_shape;
485     mlir::mhlo::ScatterDimensionNumbersAttr dim_numbers;
486     bool unique_indices;
487     const HloComputation* update_computation;
488     llvm_ir::IrArray output;
489     llvm_ir::ElementGenerator scatter_indices_gen;
490     llvm_ir::ElementGenerator updates_gen;
491     std::function<llvm::Type*(int64_t)> get_index_type;
492   };
493 
494   // Emits code for an in-place scatter using the provided scatter operation
495   // description.
496   Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk,
497                      const LaunchDimensions& launch_dimensions);
498 
499   // Emits the kernel for the given fusion using the tiled 0-2-1 transpose
500   // algorithm described in EmitHlo021Tile below.
501   Status Emit021Transpose(TransposeDimsAndParams descr,
502                           mlir::lmhlo::FusionOp fusion);
503 
504   // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
505   // algorithm to improve the memory access patterns for the input parameters
506   // with a shape that is a 0-2-1 transpose of the output tensor shape. The
507   // caller is responsible for making sure that it is safe to apply the shared
508   // memory transpose on the input parameters.
509   //
510   //
511   // For the purpose of tiling, the output tensors have a logical shape of three
512   // components 0-2-1 while the relevant input parameters have a logical shape
513   // of three components 0-1-2 in the order major to minor. The x- and y-
514   // dimensions of the tensors are tiled in square tiles with an edge length
515   // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
516   // transposes one tile: each thread copies kTileSize/kNumRows elements from
517   // the input to a shared memory tile, then the otherwise "regular HLO kernel"
518   // reads from the shared memory instead of the original input.
519   //
520   // This is similar to the following CUDA algorithm in TensorFlow:
521   // https://goo.gl/MStRV6.
522   //
523   // `kTileSize` should usually be same as warp size. We currently choose 32 for
524   // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
525   //
526   // TODO(b/33320379): Here each block transposes 1 tile. It may be more
527   // efficient to launch fewer blocks so each transposes many tiles.
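  //
  // Rough CUDA-like sketch of one tile (kTileSize = 32, kNumRows = 4 as stated
  // above; illustrative only, not the exact emitted code):
  //
  //   ```
  // __shared__ float tile[32][33];  // Padded column to reduce bank conflicts.
  // int x = blockIdx.x * 32 + threadIdx.x;
  // for (int j = threadIdx.y; j < 32; j += 4) {
  //   int y = blockIdx.y * 32 + j;
  //   tile[j][threadIdx.x] = input[y][x];  // Coalesced read of the input.
  // }
  // __syncthreads();
  // // The otherwise "regular HLO kernel" now reads tile[threadIdx.x][j] from
  // // shared memory instead of the strided original input.
  //   ```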
528   Status EmitHlo021Tile(mlir::lmhlo::FusionOp fusion,
529                         absl::Span<const llvm_ir::IrArray> operand_arrays,
530                         absl::Span<const llvm_ir::IrArray> output_arrays,
531                         TransposeDimsAndParams descr,
532                         const TilingScheme& tiling_scheme,
533                         const LaunchDimensions& launch_dimensions);
534 
535   Status EmitScatter(mlir::lmhlo::FusionOp fusion_op,
536                      const HloComputation* fused_computation);
537 
538   Status EmitDynamicUpdateSlice(mlir::lmhlo::FusionOp fusion_op,
539                                 const HloComputation* fused_computation);
540 
541   struct TilingKernelInfo {
542     // Tiling bounds.
543     ValueVector3 output_tile_bounds;
544 
545     // Starting tile, as calculated from block id only.
546     llvm_ir::IrArray::Index tile_origin;
547 
548     // Thread meta-info.
549     ThreadIdInfo thread_id_info;
550   };
551 
552   // Emits a kernel for the hlo instruction using the given kernel mapping
553   // scheme.
554   StatusOr<TilingKernelInfo> EmitTilingKernel(
555       const TilingScheme& tiling_scheme, llvm::Type* index_ty,
556       const TileElementGenerator& tile_element_generator);
557 
558   // Emits code to iterate through a 2-dimensional tile with the given tile
559   // dimensions and strides, and calls the callback at each iteration.
560   //
561   // `thread_id_y` and `thread_id_x` are the intra-tile coordinates for
562   // the first element to process, and `index` is the index for the origin of
563   // the tile. Emits bounds check to ensure that each processed element
564   // is within the boundary defined by `tile_dimensions`.
565   //
566   // Rough pseudocode:
567   //
568   // Given: tile_dimensions, x_offset, y_offset
569   //
570   // for (y = 0; y < tile_dimensions[Y]; y += num_threads_y) {
571   //   for (x = 0; x < tile_dimensions[X]; x++) {
572   //
573   //     y_pos = y_offset + y
574   //     x_pos = x_offset + x * stride
575   //
576   //     if (x_pos < tile_dimensions[X]) {
577   //       emit_elem_function(y_pos, x_pos);
578   //     }
579   //   }
580   // }
581   //
582   void EmitTile(
583       const TilingScheme& tiling_scheme,
584       const llvm_ir::IrArray::Index& tile_origin_index,
585       const ThreadIdInfo& thread_id_info, ValueVector3 tile_dimensions,
586       const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
587 
588   // Emits code to process a tensor element in a tile for the given kLoop
589   // fusion HLO whose parameters are a 0-2-1 transpose of its outputs.
590   //
591   // y_loc: The y coordinate within a tile.
592   // x_loc: The x coordinate within a tile.
593   void EmitTileElementForTranspose(
594       const ThreadIdInfo& thread_id_info, mlir::lmhlo::FusionOp fusion,
595       absl::Span<const llvm_ir::IrArray> operand_arrays,
596       absl::Span<const llvm_ir::IrArray> output_arrays,
597       const llvm_ir::IrArray::Index& index, const TilingScheme& tiling_scheme,
598       llvm::Value* y_loc, llvm::Value* x_loc,
599       absl::Span<llvm::Value* const> param_shmem_buffers);
600 
601   // Creates accumulator allocas, populates them with initial values, generates
602   // __shared__ caches and returns the populated object.
603   ReductionCodegenState GenerateReductionCodegenState(
604       mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info,
605       absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
606       FusedIrEmitter& fused_emitter);
607 
608   // Wraps up the code generation for a tile block of a reduction kernel:
609   // writes the calculated output into the output tensor.
610   void EmitReductionOutput(
611       llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion,
612       absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
613       const ReductionOutputMap& result_ir_arrays,
614       const ReductionCodegenState& reduction_codegen_state,
615       const TilingKernelInfo& tiling_kernel_info);
616 
617   // Returns the address to write the reduction output to.
618   llvm::Value* GetOutputAddressForReduction(
619       int partial_result_idx, llvm::Type* index_ty,
620       const ReductionCodegenState& reduction_codegen_state,
621       const TilingKernelInfo& tiling_kernel_info,
622       const IrEmitterUnnested::ReductionOutputMap& output_arrays,
623       const HloReduceInstruction* reduction, int output_idx);
624 
625   // `current_output`: the value the tile has calculated.
626   // `output_address`: address where the output value has to be written.
627   void EmitReductionOutputForRowReduction(
628       const TilingKernelInfo& tiling_kernel_info,
629       const ReductionCodegenState& reduction_codegen_state,
630       llvm::Type* index_ty, const ReductionOutputMap& output_arrays,
631       const HloReduceInstruction* reduction, int partial_result_idx);
632 
633   // Same arguments as EmitReductionOutputForRowReduction.
634   void EmitReductionOutputForColumnReduction(
635       const TilingKernelInfo& tiling_kernel_info,
636       const ReductionCodegenState& reduction_codegen_state,
637       llvm::Type* index_ty, const ReductionOutputMap& output_arrays,
638       const HloReduceInstruction* reduction, int partial_result_idx);
639 
640   // Emits code for reductions in the output_instructions.
641   Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion,
642                             absl::Span<HloInstruction* const> instr_index_group,
643                             FusedIrEmitter& fused_emitter,
644                             const ReductionOutputMap& result_ir_arrays,
645                             const ReductionCodegenInfo& reduction_info,
646                             const Shape& input_shape);
647 
648   // Generate a single element of the tile (update the accumulator state) for a
649   // given reducer of index `i`.
650   void GenerateElementForReducer(
651       const HloReduceInstruction* reduction, llvm::Value* partial_result_index,
652       const ReductionCodegenState& codegen_state,
653       const llvm_ir::IrArray::Index& index_without_linear,
654       const llvm_ir::IrArray::Index& input_index, int num_partial_results,
655       const ReductionOutputMap& result_ir_arrays);
656 
657   // Emits a shuffle-down reduction for the given `partial_result_addresses`
658   // using the reduction computation `reducer`, and writes the output back into
659   // `partial_result_addresses`.
660   //
661   // Multiple partial_result_address inputs happen when doing variadic
662   // reduction: each one should get the output value.
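  //
  // CUDA-like sketch of the shuffle-down loop (illustrative; `+` stands in for
  // the actual reducer computation):
  //
  //   for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
  //     partial += __shfl_down_sync(0xffffffff, partial, offset);
  //   }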
663   void EmitFullWarpShuffleDownLoopForReduce(
664       const HloComputation* reducer,
665       absl::Span<std::pair<llvm::Value* const, llvm::Type* const>>
666           partial_result_addresses,
667       int threads_per_block, int num_results_per_warp = 1);
668 
669   // Allocates a shared tile of given dimensions, applying scaling specified in
670   // tiling_scheme as a major-most dimension to avoid collisions.
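  //
  // For example (illustrative): requesting a {32, 33} tile of floats with a
  // virtual-thread scaling factor of 2 in the tiling scheme yields a
  // [2 x [32 x [33 x float]]] shared buffer, which is then indexed through
  // ThreadIdInfo::GEPIntoSharedMemory.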
671   llvm::GlobalVariable* AllocateShared(
672       const TilingScheme& tiling_scheme, llvm::Type* element_type,
673       absl::Span<int64_t const> dimensions_major_to_minor,
674       absl::string_view buffer_name = "");
675 
676   StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkImpl(
677       absl::string_view name, Thunk::ThunkInfo thunk_info,
678       absl::Span<const BufferSlice> slices,
679       std::vector<llvm_ir::IrArray>* ir_arrays,
680       const LaunchDimensions& launch_dimensions);
681 
682   StatusOr<std::unique_ptr<Thunk>> BuildKernelThunk(
683       mlir::Operation* op, mlir::ValueRange operands,
684       Thunk::ThunkInfo thunk_info, std::vector<llvm_ir::IrArray>* ir_arrays,
685       const LaunchDimensions& launch_dimensions);
686 
687   StatusOr<std::unique_ptr<Thunk>> BuildKernelThunk(
688       mlir::Operation* op, Thunk::ThunkInfo thunk_info,
689       std::vector<llvm_ir::IrArray>* ir_arrays,
690       const LaunchDimensions& launch_dimensions);
691 
692   // Returns a thunk that, given a reduce or select-and-scatter op,
693   // initializes its memory to the appropriate initial value.
694   std::unique_ptr<Thunk> BuildConstantInitializerThunk(
695       absl::Span<const uint8_t> init_value, const BufferAllocation::Slice& dest,
696       const Shape& output_shape);
697 
698   StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
699       mlir::Value init_value, mlir::Value dest);
700 
701   StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(mlir::Operation* op,
702                                                          mlir::Value init_value,
703                                                          mlir::Value dest);
704   StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
705       mlir::lmhlo::FusionOp fusion, int output_index);
706 
707   // Returns a WhileThunk that invokes thunk sequences for 'condition' and
708   // 'body' sub-computations of while instruction 'hlo'.
709   StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(
710       mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info);
711 
712   // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
713   // sequence from the 'body' sub-computation of the while instruction 'hlo'.
714   StatusOr<std::unique_ptr<Thunk>> BuildForThunk(
715       mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
716       const int64_t loop_limit);
717 
718   // Returns a ConditionalThunk which executes the thunk sequence for the
719   // 'branch_computation' corresponding to the predicate/branch_index of the
720   // given conditional instruction.
721   StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
722       const HloInstruction* conditional);
723 
724   // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
725   // id.
726   //
727   // Returns a struct containing these values.
728   //
729   // In the presence of thread scaling in the tiling scheme, this may return early
730   // if the combination of thread_id/block_id does not correspond to a real block.
731   // Assumes the current function returns void.
732   StatusOr<ThreadIdInfo> EmitThreadIdInfo(const TilingScheme& tiling_scheme,
733                                           llvm::Type* index_ty);
734   // Emits __syncthreads(), a synchronization barrier for all threads in a block.
735   llvm::CallInst* EmitSyncThreads();
736 
737   // Emits current thread id with the given type.
738   //
739   // Sets the return value range to [0, threads_per_block).
740   llvm::Value* EmitThreadId(int64_t threads_per_block, llvm::Type* index_ty);
741 
742   // Emits current block id.
743   llvm::Value* EmitBlockId(int32_t num_blocks, llvm::Type* index_ty);
744 
745   // Prints a given format string with the given arguments, prefixed with
746   // thread id and block id, and postfixed with a newline.
747   //
748   // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
749   // to only given thread and/or block id.
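  //
  // Example (illustrative, assuming `accum` is an llvm::Value*): print `accum`
  // only from thread 0 of block 0:
  //
  //   EmitPrintfWithThreadId("accum=%f", {accum},
  //                          /*thread_id_filter=*/0, /*block_id_filter=*/0);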
750   void EmitPrintfWithThreadId(
751       absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
752       std::optional<int64_t> thread_id_filter = std::nullopt,
753       std::optional<int64_t> block_id_filter = std::nullopt);
754 
755   StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
756       mlir::Region* region, bool is_fusion);
757 
758   // Returns the last generated thunk.
759   Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
760 
761   Status AssertNonDeterminismIsOkay(const std::string& op_name);
762 
763   // The thunk sequence this IrEmitter generates for the input computation.
764   ThunkSequence thunk_sequence_;
765 
766   // Maps all-reduce-start ops to their thunks so that the corresponding done op can find them.
767   absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>
768       all_reduce_start_thunks_;
769 
770   // Begin optional members for XLA HLO -> LMHLO:
771   absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
772       scratch_nested_computations_;
773   // End optional members for XLA HLO -> LMHLO.
774 
775   // __shared__ memory uses a different address space, so we cast it to
776   // global address space before writing or reading.
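  //
  // On NVPTX, for instance, __shared__ buffers live in addrspace(3), so the
  // emitted cast is an `addrspacecast` instruction, roughly (illustrative):
  //
  //   %p = addrspacecast float addrspace(3)* %tile_elem to float*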
777   llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type,
778                                   llvm::Twine name = "");
779 
780   // Returns the ShapedSlices for the given operands.
781   StatusOr<std::vector<ShapedSlice>> GetShapedSlices(
782       mlir::Operation::operand_range operands);
783 
784   // Returns the buffer allocation Slice for the given operands.
785   StatusOr<std::vector<BufferAllocation::Slice>> GetSlices(
786       mlir::Operation::operand_range operands);
787 };
788 
789 }  // namespace gpu
790 }  // namespace xla
791 
792 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
793