/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include <array>
#include <functional>
#include <string>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace gpu {

struct BufferSlice {
  // The root buffer to look at.
  BufferAllocation::Slice buffer_slice;

  // The global constant name of the buffer, if it's a constant.
  std::string constant_name;

  // The buffer is modified by the kernel.
  bool written = false;

  Shape shape;
};

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains. Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested! Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The body of a fusion node. IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  // Contains threading information. Note that for performance we might apply
  // thread id "scaling" where the physical thread id (to achieve good SM
  // occupancy) will differ from the logical thread id. This struct contains
  // logical thread ids, along with meta-information about the scaling applied.
  struct ThreadIdInfo {
    ThreadIdInfo(llvm::Value* thread_id, llvm::Value* thread_id_x,
                 llvm::Value* thread_id_y, llvm::Value* lane_id,
                 llvm::Value* block_id, llvm::Value* scaling)
        : thread_id(thread_id),
          thread_id_x(thread_id_x),
          thread_id_y(thread_id_y),
          lane_id(lane_id),
          block_id(block_id),
          scaling(scaling) {}

    llvm::Value* thread_id;

    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
    llvm::Value* thread_id_x;

    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
    llvm::Value* thread_id_y;

    // Lane id: `thread_id % WarpSize`
    llvm::Value* lane_id;

    // Block id.
    llvm::Value* block_id;

    // Emits a GEP into shared memory, taking virtual thread scaling into
    // account. Automatically inserts the first zero required by LLVM GEP.
    // Defined on ThreadIdInfo to keep `scaling` private.
    //
    // Same semantics as CreateInBoundsGEP.
    llvm::Value* GEPIntoSharedMemory(
        llvm::IRBuilder<>* b, llvm::GlobalVariable* shared,
        absl::Span<llvm::Value* const> idx_major_to_minor,
        const llvm::Twine& name = "") const;

    // Calculates the pointee type of the llvm::Value returned by
    // GEPIntoSharedMemory.
    llvm::Type* GEPIntoSharedMemoryType(
        llvm::GlobalVariable* shared,
        absl::Span<llvm::Value* const> idx_major_to_minor) const;

   private:
    llvm::Value* scaling;
  };

  absl::string_view platform_name() const {
    return ir_emitter_context_->platform_name();
  }

  using ValueVector3 = std::array<llvm::Value*, 3>;

  // A function object to generate code to process one element in a tile.
  //
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //   x_iter_num has a value of 0..N-1 to identify the element being
  //   processed.
  using EmitElementFunction = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      llvm::Value* y_loc, llvm::Value* x_loc, llvm::Value* x_iter_num)>;

  using ConstantGenerator = std::function<llvm::Value*(int64_t)>;

  // A function to generate the code to emit the entire tile.
  //
  // index: Absolute coordinate of the start of the tile in input.
  // tile_dimensions: Size of the tile.
  using TileElementGenerator = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      ValueVector3 tile_dimensions)>;

  // Fusion root -> array of indexes, one per reduction output.
  using ReductionOutputMap =
      ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;

  using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;

  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
      const HloModuleConfig& hlo_module_config,
      IrEmitterContext* ir_emitter_context);

  // Transfers the ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
  }
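
  // A minimal driver sketch (assumed caller code, not part of this class): the
  // real GPU compiler pipeline also sets up the IrEmitterContext and the LLVM
  // module, but the emitter itself is driven with the three calls below.
  // `hlo_module_config`, `ir_emitter_context` and `lmhlo_entry_region` are
  // placeholder names for the caller's objects.
  //
  // ```
  // TF_ASSIGN_OR_RETURN(
  //     std::unique_ptr<IrEmitterUnnested> ir_emitter,
  //     IrEmitterUnnested::Create(hlo_module_config, &ir_emitter_context));
  // TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&lmhlo_entry_region));
  // std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
  // ```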
  // Emits code for the given LMHLO region.
  //
  // Also populates related information to 'ir_emitter_context_' for
  // large-constant initializations. Large constants don't get initializers in
  // the generated code and so must be initialized by XLA. The value of these
  // constants will be stored in 'content'. Constants with initializers in the
  // generated code will have empty 'content'.
  Status EmitLmhloRegion(mlir::Region* region);

  static void GetDependentDialects(mlir::DialectRegistry& registry);

 private:
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    IrEmitterContext* ir_emitter_context);

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter. It also mixes in some special handling for custom kernels
  // via the ThunkEmitter.
  Status EmitConstant(mlir::Operation* op);

  Status EmitCopy(mlir::Operation* op);

  Status EmitConditional(mlir::Operation* op);
  Status EmitConvolutionThunk(mlir::Operation* op);
  Status EmitGemmThunk(mlir::Operation* op);
#if GOOGLE_CUDA
  Status EmitCublasLtMatmulThunk(mlir::Operation* op);
#endif  // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCholeskyThunk(mlir::Operation* op);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCustomCallThunk(mlir::Operation* op);
  Status EmitFftThunk(mlir::Operation* op);
  Status EmitFusion(mlir::Operation* op);
  Status EmitLaunchFunc(mlir::Operation* op);
  Status EmitLoopFusion(mlir::Operation* op);
  Status EmitReduce(mlir::Operation* op);
  Status EmitSelectAndScatter(mlir::Operation* op);
  Status EmitWhile(mlir::Operation* op);
  Status EmitInfeed(mlir::Operation* op);
  Status EmitOutfeed(mlir::Operation* op);
  Status EmitRngGetAndUpdateState(mlir::Operation* op);
  Status EmitScatter(mlir::Operation* op);
  Status EmitSort(mlir::Operation* op);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitTriangularSolveCustomCall(mlir::Operation* op);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  template <typename NcclThunkType, typename OpTy>
  Status EmitNcclThunk(mlir::Operation* op);
  Status EmitAllReduceDone(mlir::Operation* op);

  template <typename ThunkType, typename OpT>
  Status EmitReplicaOrPartitionId(mlir::Operation* op);

  Status EmitCollectivePermute(mlir::Operation* op);

  Status EmitOp(mlir::Operation* op);

  static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op);

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_.emplace_back(std::move(thunk));
  }

  // Loads data from a potentially unaligned address. If the address is offset
  // by `alignment_bytes`, data is read in units of `alignment_bytes` to avoid
  // memory read misalignment in CUDA; otherwise, the entire value is loaded
  // from the given memory address.
  //
  // address: the memory address to load data from.
  // data_type: the type of data to load.
  // alignment_bytes: the number of bytes required to align. The number of
  //   bytes of the data_type must be divisible by alignment_bytes.
  llvm::Value* CreateLoad(llvm::Value* address, llvm::Type* data_type,
                          int alignment_bytes);
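
  // For illustration only, the chunked load above behaves roughly like the
  // following pseudocode (assuming a little-endian target and
  // N = sizeof(data_type) divisible by alignment_bytes; the emitted IR
  // differs, and `load_integer`/`bitcast` are placeholder names):
  //
  // ```
  // uint64 result = 0;
  // for (int i = 0; i < N / alignment_bytes; ++i) {
  //   // Each chunk is naturally aligned, so the narrow load is legal.
  //   uint64 chunk = load_integer(address + i * alignment_bytes,
  //                               /*bytes=*/alignment_bytes);
  //   result |= chunk << (8 * alignment_bytes * i);
  // }
  // return bitcast<data_type>(result);
  // ```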
  // Stores data at a potentially unaligned address. If the address is offset
  // by `alignment_bytes`, data is stored in units of `alignment_bytes` to
  // avoid memory write misalignment in CUDA; otherwise, the entire value is
  // stored at the given memory address.
  //
  // data: the data to be stored.
  // address: the memory address to store data to.
  // alignment_bytes: the number of bytes required to align. The number of
  //   bytes of the data_type must be divisible by alignment_bytes.
  void CreateStore(llvm::Value* data, llvm::Value* address,
                   int alignment_bytes);

  // Input = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array (with dynamic dimension metadata at the end)}
  //
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //   Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //
  // Pseudocode for padToStatic on a 2d array:
  // ```
  // void padToStatic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // extract the dynamic dimensions from the source array's metadata
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
  //
  //   // only one thread needs to store the dynamic dimensions
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearize(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearize(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitPadToStatic(mlir::Operation* op);

  // Input = {dynamic array (with dynamic dimension metadata at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  //
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //   Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //
  // Pseudocode for sliceToDynamic on a 2d array:
  // ```
  // void sliceToDynamic(int** input, int** output, int threads_per_block,
  //                     int meta_data_offset, int max_num_element,
  //                     int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // calculate the location where the metadata needs to be inserted
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
  //
  //   // only one thread needs to store the dynamic dimensions
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *output[1];
  //     *dyn_dim1_size = *output[2];
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearize(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearize(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitSliceToDynamic(mlir::Operation* op);

  StatusOr<BufferAllocation::Slice> GetAllocationSlice(
      mlir::Value v, std::string* constant_name = nullptr);

  int64_t ByteSizeOf(const Shape& shape) const {
    return llvm_ir::ByteSizeOf(
        shape, ir_emitter_context_->llvm_module()->getDataLayout());
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the
  // module. This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      absl::string_view name, absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(const ReductionOutputMap& result_ir_arrays,
                                   const llvm_ir::IrArray::Index& index,
                                   const ReductionCodegenInfo& reduction_info,
                                   const ExtraOutputGensMap& extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Row reduction uses the following algorithm described in CUDA-like
  // pseudocode:
  //
  // ```
  // __global__ void reduce(int num_rows, float* in, float* out) {
  //   __shared__ float cache[32];
  //   int offset = blockDim.x * blockIdx.x + threadIdx.x;
  //   if (offset >= num_rows) return;
  //   float accum = 0;
  //   for (int i = offset; i < num_rows; i += blockDim.x) {
  //     accum += in[i];
  //   }
  //   accum = warp_reduce(accum);
  //   if (threadIdx.x % WarpSize == 0) {
  //     cache[threadIdx.x / WarpSize] = accum;
  //   }
  //   __syncthreads();
  //   if (threadIdx.x / WarpSize == 0) {
  //     bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
  //     float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
  //     block_accum = warp_reduce(block_accum);
  //     if (threadIdx.x == 0) {
  //       *out += block_accum;
  //     }
  //   }
  // }
  // ```
  //
  // Column reduction uses the following algorithm:
  //
  // ```
  // void reduce(float** in, float* out) {
  //   __shared__ float cache[32][33];
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   int tile_size = 128;
  //
  //   float accum = 0;
  //   for (int i = 0; i < tile_size; i++) {
  //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
  //   }
  //   cache[thread_id.x][thread_id.y] = accum;
  //
  //   __syncthreads();
  //   accum = cache[thread_id.y][thread_id.x];
  //   accum = warp_reduce(accum);  // Sum all the values of `accum` in the
  //                                // same warp.
  //
  //   if (thread_id.y % 32 == 0) {
  //     out[block_id * 32 + thread_id.x] = accum;
  //   }
  // }
  // ```
  //
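  // `warp_reduce` in both sketches above stands for the usual shuffle-down
  // tree reduction (shown here for a float sum; a sketch of the idea, not the
  // IR this class emits -- see EmitFullWarpShuffleDownLoopForReduce below):
  //
  // ```
  // __device__ float warp_reduce(float val) {
  //   for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
  //     val += __shfl_down_sync(0xffffffff, val, offset);
  //   }
  //   return val;
  // }
  // ```
  //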
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // run sequentially, while different groups run in parallel.
  //
  // We use the raw block_id_y to select the reduce group for execution without
  // complicating the index calculation in the code generation of the reduce
  // instructions. In other words, a block_id_y is assigned to a group, so
  // different groups can run in parallel.
  Status EmitUnnestedReduction(mlir::lmhlo::FusionOp fusion);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  StatusOr<ReductionCodegenInfo> ComputeReductionCodegenInfo(
      mlir::lmhlo::FusionOp fusion, mlir::mhlo::ReduceOp first_reduce);

  // Generates code for input-fusible slices.
  //
  // Prerequisite: ROOT is either a slice or a tuple of slices. The input
  // shapes of all ROOT slices need to be the same, while their output shapes
  // can differ. On the other hand, the input ranges of the slices can overlap.
  // Further generalization/specialization can be added when the need arises.
  Status EmitInputFusibleNonStridedSlices(mlir::Operation* op);

  Status EmitElementForInputFusibleSlices(
      const HloComputation* fused_computation,
      absl::Span<const llvm_ir::IrArray> ir_arrays,
      const llvm_ir::IrArray::Index& index);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. Scatter indices are taken from `scatter_indices_gen`,
  // updates from `updates_gen`. The output buffer is expected to already
  // contain the operand values. If unique_indices is false, we will use an
  // atomic update. Using true for unique_indices behaves properly only when
  // it is guaranteed that the indices to be updated do not overlap. The caller
  // is responsible for ensuring this is the case.
  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
                     const LaunchDimensions& launch_dimensions,
                     const llvm_ir::IrArray& output,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen,
                     std::function<llvm::Type*(int64_t)> get_index_type);

  // Structure describing a scatter operation for IR emission.
  // TODO(jurahul): Migrate element generators to use MLIR.
  // Migrate update_computation to be an MLIR Region.
  struct ScatterDescriptor {
    std::string name;
    Shape operand_shape;
    Shape scatter_indices_shape;
    Shape updates_shape;
    mlir::mhlo::ScatterDimensionNumbersAttr dim_numbers;
    bool unique_indices;
    const HloComputation* update_computation;
    llvm_ir::IrArray output;
    llvm_ir::ElementGenerator scatter_indices_gen;
    llvm_ir::ElementGenerator updates_gen;
    std::function<llvm::Type*(int64_t)> get_index_type;
  };

  // Emits code for an in-place scatter using the provided scatter operation
  // description.
  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk,
                     const LaunchDimensions& launch_dimensions);
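
  // The per-element loop emitted by the EmitScatter overloads above has
  // roughly this shape (a sketch only; the exact bounds handling and index
  // mapping follow the HLO scatter semantics, and the update is atomic unless
  // unique_indices is true):
  //
  // ```
  // parallel for each update element `i`:
  //   scatter_index = scatter_indices_gen(i);
  //   update_value  = updates_gen(i);
  //   if (scatter_index is in bounds of `output`) {
  //     if (unique_indices)
  //       output[scatter_index] =
  //           update_computation(output[scatter_index], update_value);
  //     else
  //       atomically apply update_computation to output[scatter_index];
  //   }
  // ```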
  // Emits a kernel for the given fusion using the tiled 0-2-1 transpose
  // algorithm described below.
  Status Emit021Transpose(TransposeDimsAndParams descr,
                          mlir::lmhlo::FusionOp fusion);

  // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
  // algorithm to improve the memory access patterns for the input parameters
  // with a shape that is a 0-2-1 transpose of the output tensor shape. The
  // caller is responsible for making sure that it is safe to apply the shared
  // memory transpose on the input parameters.
  //
  // For the purpose of tiling, the output tensors have a logical shape of
  // three components 0-2-1 while the relevant input parameters have a logical
  // shape of three components 0-1-2 in the order major to minor. The x- and
  // y-dimensions of the tensors are tiled in square tiles with an edge length
  // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
  // transposes one tile: each thread copies kTileSize/kNumRows elements from
  // the input to a shared memory tile, then the otherwise "regular HLO kernel"
  // reads from the shared memory instead of the original input.
  //
  // This is similar to the following CUDA algorithm in TensorFlow:
  // https://goo.gl/MStRV6.
  //
  // `kTileSize` should usually be the same as the warp size. We currently
  // choose 32 for `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8
  // for `kNumRows`.
  //
  // TODO(b/33320379): Here each block transposes 1 tile. It may be more
  // efficient to launch fewer blocks so each transposes many tiles.
  Status EmitHlo021Tile(mlir::lmhlo::FusionOp fusion,
                        absl::Span<const llvm_ir::IrArray> operand_arrays,
                        absl::Span<const llvm_ir::IrArray> output_arrays,
                        TransposeDimsAndParams descr,
                        const TilingScheme& tiling_scheme,
                        const LaunchDimensions& launch_dimensions);

  Status EmitScatter(mlir::lmhlo::FusionOp fusion_op,
                     const HloComputation* fused_computation);

  Status EmitDynamicUpdateSlice(mlir::lmhlo::FusionOp fusion_op,
                                const HloComputation* fused_computation);

  struct TilingKernelInfo {
    // Tiling bounds.
    ValueVector3 output_tile_bounds;

    // Starting tile, as calculated from block id only.
    llvm_ir::IrArray::Index tile_origin;

    // Thread meta-info.
    ThreadIdInfo thread_id_info;
  };

  // Emits a kernel for the hlo instruction using the given kernel mapping
  // scheme.
  StatusOr<TilingKernelInfo> EmitTilingKernel(
      const TilingScheme& tiling_scheme, llvm::Type* index_ty,
      const TileElementGenerator& tile_element_generator);
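
  // Very roughly, the kernel produced by EmitTilingKernel has this skeleton
  // (a sketch assembled from the fields documented on TilingKernelInfo and
  // ThreadIdInfo above, not a literal transcription of the emitted IR):
  //
  // ```
  // thread_id_info = EmitThreadIdInfo(tiling_scheme, index_ty);
  // // May exit early if thread/block scaling maps this thread to no work.
  // tile_origin = starting tile derived from thread_id_info.block_id;
  // tile_bounds = tile dimensions clamped at the tensor boundary;
  // tile_element_generator(thread_id_info, tile_origin, tile_bounds);
  // ```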
  // Emits code to iterate through a 2-dimensional tile with given tile
  // dimensions and strides, and calls the callback at each iteration.
  //
  // `thread_id_y` and `thread_id_x` are the intra-tile coordinates for the
  // first element to process, and `index` is the index for the origin of the
  // tile. Emits a bounds check to ensure that each processed element is within
  // the boundary defined by `tile_dimensions`.
  //
  // Rough pseudocode:
  //
  // Given: tile_dimensions, x_offset, y_offset
  //
  // for (y = 0; y < tile_dimensions[Y]; y += num_threads_y) {
  //   for (x = 0; x < tile_dimensions[X]; x++) {
  //     y_pos = y_offset + y
  //     x_pos = x_offset + x * stride
  //     if (x_pos < tile_dimensions[X]) {
  //       emit_elem_function(y_pos, x_pos);
  //     }
  //   }
  // }
  //
  void EmitTile(
      const TilingScheme& tiling_scheme,
      const llvm_ir::IrArray::Index& tile_origin_index,
      const ThreadIdInfo& thread_id_info, ValueVector3 tile_dimensions,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);

  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are 0-2-1 transposes of its outputs.
  //
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForTranspose(
      const ThreadIdInfo& thread_id_info, mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> operand_arrays,
      absl::Span<const llvm_ir::IrArray> output_arrays,
      const llvm_ir::IrArray::Index& index, const TilingScheme& tiling_scheme,
      llvm::Value* y_loc, llvm::Value* x_loc,
      absl::Span<llvm::Value* const> param_shmem_buffers);

  // Creates accumulator allocas, populates them with initial values, generates
  // __shared__ caches and returns the populated object.
  ReductionCodegenState GenerateReductionCodegenState(
      mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info,
      absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
      FusedIrEmitter& fused_emitter);

  // Wraps up the code generation for a tile block of a reduction kernel:
  // writes the calculated output into the output tensor.
  void EmitReductionOutput(
      llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion,
      absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
      const ReductionOutputMap& result_ir_arrays,
      const ReductionCodegenState& reduction_codegen_state,
      const TilingKernelInfo& tiling_kernel_info);

  // Returns the address to write the reduction output to.
  llvm::Value* GetOutputAddressForReduction(
      int partial_result_idx, llvm::Type* index_ty,
      const ReductionCodegenState& reduction_codegen_state,
      const TilingKernelInfo& tiling_kernel_info,
      const IrEmitterUnnested::ReductionOutputMap& output_arrays,
      const HloReduceInstruction* reduction, int output_idx);

  // `current_output`: the value the tile has calculated.
  // `output_address`: address where the output value has to be written.
  void EmitReductionOutputForRowReduction(
      const TilingKernelInfo& tiling_kernel_info,
      const ReductionCodegenState& reduction_codegen_state,
      llvm::Type* index_ty, const ReductionOutputMap& output_arrays,
      const HloReduceInstruction* reduction, int partial_result_idx);

  // Same arguments as EmitReductionOutputForRowReduction.
  void EmitReductionOutputForColumnReduction(
      const TilingKernelInfo& tiling_kernel_info,
      const ReductionCodegenState& reduction_codegen_state,
      llvm::Type* index_ty, const ReductionOutputMap& output_arrays,
      const HloReduceInstruction* reduction, int partial_result_idx);

  // Emits code for reductions in the output_instructions.
  Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion,
                            absl::Span<HloInstruction* const> instr_index_group,
                            FusedIrEmitter& fused_emitter,
                            const ReductionOutputMap& result_ir_arrays,
                            const ReductionCodegenInfo& reduction_info,
                            const Shape& input_shape);

  // Generates a single element of the tile (updates the accumulator state) for
  // a given reducer of index `i`.
  void GenerateElementForReducer(
      const HloReduceInstruction* reduction, llvm::Value* partial_result_index,
      const ReductionCodegenState& codegen_state,
      const llvm_ir::IrArray::Index& index_without_linear,
      const llvm_ir::IrArray::Index& input_index, int num_partial_results,
      const ReductionOutputMap& result_ir_arrays);

  // Emits a shuffle-down reduction for the `partial_result_addresses` using
  // the reduction computation `reducer`, and writes the output back into
  // `partial_result_addresses`.
  //
  // Multiple partial_result_address inputs happen when doing variadic
  // reduction: each one should get the output value.
  void EmitFullWarpShuffleDownLoopForReduce(
      const HloComputation* reducer,
      absl::Span<std::pair<llvm::Value* const, llvm::Type* const>>
          partial_result_addresses,
      int threads_per_block, int num_results_per_warp = 1);

  // Allocates a shared tile of given dimensions, applying the scaling
  // specified in tiling_scheme as a major-most dimension to avoid collisions.
  llvm::GlobalVariable* AllocateShared(
      const TilingScheme& tiling_scheme, llvm::Type* element_type,
      absl::Span<int64_t const> dimensions_major_to_minor,
      absl::string_view buffer_name = "");

  StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkImpl(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const BufferSlice> slices,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<Thunk>> BuildKernelThunk(
      mlir::Operation* op, mlir::ValueRange operands,
      Thunk::ThunkInfo thunk_info, std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<Thunk>> BuildKernelThunk(
      mlir::Operation* op, Thunk::ThunkInfo thunk_info,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  // Returns a thunk that, given a reduce or select-and-scatter op,
  // initializes its memory to the appropriate initial value.
  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
      absl::Span<const uint8_t> init_value, const BufferAllocation::Slice& dest,
      const Shape& output_shape);

  StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
      mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(mlir::Operation* op,
                                                         mlir::Value init_value,
                                                         mlir::Value dest);
  StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
      mlir::lmhlo::FusionOp fusion, int output_index);

  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
  // 'body' sub-computations of while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info);
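
  // At run time, the thunk returned by BuildWhileThunk behaves roughly like
  // the loop below (an executor-side sketch with made-up helper names, not
  // code generated by this class; the predicate is produced by the condition
  // thunk sequence and read back from device memory):
  //
  // ```
  // while (true) {
  //   ExecuteThunks(condition_thunk_sequence);
  //   if (!ReadPredicateFromDevice()) break;
  //   ExecuteThunks(body_thunk_sequence);
  // }
  // ```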
  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildForThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
      const int64_t loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
      const HloInstruction* conditional);

  // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
  // id.
  //
  // Returns a struct containing these values.
  //
  // In the presence of thread scaling in the tiling scheme, this may return
  // early if the combination of thread_id/block_id does not correspond to a
  // real block. Assumes the current function returns void.
  StatusOr<ThreadIdInfo> EmitThreadIdInfo(const TilingScheme& tiling_scheme,
                                          llvm::Type* index_ty);

  // Emits __syncthreads(), a synchronization barrier for all threads in a
  // block.
  llvm::CallInst* EmitSyncThreads();

  // Emits the current thread id with the given type.
  //
  // Sets the return value range to [0, threads_per_block).
  llvm::Value* EmitThreadId(int64_t threads_per_block, llvm::Type* index_ty);

  // Emits the current block id.
  llvm::Value* EmitBlockId(int32_t num_blocks, llvm::Type* index_ty);

  // Prints a given format string with the given arguments, prefixed with the
  // thread id and block id, and postfixed with a newline.
  //
  // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
  // to only the given thread and/or block id.
  void EmitPrintfWithThreadId(
      absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
      std::optional<int64_t> thread_id_filter = std::nullopt,
      std::optional<int64_t> block_id_filter = std::nullopt);

  StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
      mlir::Region* region, bool is_fusion);

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

  Status AssertNonDeterminismIsOkay(const std::string& op_name);

  // The thunk sequence this IrEmitter generates for the input computation.
  ThunkSequence thunk_sequence_;

  // Maps all-reduce-start ops to their thunks so that the corresponding
  // all-reduce-done op can access them.
  absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>
      all_reduce_start_thunks_;

  // Begin optional members for XLA HLO -> LMHLO:
  absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
      scratch_nested_computations_;
  // End optional members for XLA HLO -> LMHLO.

  // __shared__ memory uses a different address space, so we cast it to the
  // global address space before writing or reading.
  llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type,
                                  llvm::Twine name = "");

  // Returns the ShapedSlices for the given operands.
  StatusOr<std::vector<ShapedSlice>> GetShapedSlices(
      mlir::Operation::operand_range operands);

  // Returns the buffer allocation Slices for the given operands.
  StatusOr<std::vector<BufferAllocation::Slice>> GetSlices(
      mlir::Operation::operand_range operands);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_