1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
17 
18 #include <stdlib.h>
19 
20 #include <atomic>
21 #include <functional>
22 #include <iterator>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <variant>
27 #include <vector>
28 
29 #include "absl/strings/numbers.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/types/variant.h"
32 #include "llvm/AsmParser/Parser.h"
33 #include "llvm/Bitcode/BitcodeReader.h"
34 #include "llvm/Bitcode/BitcodeWriter.h"
35 #include "llvm/IR/DiagnosticInfo.h"
36 #include "llvm/IR/DiagnosticPrinter.h"
37 #include "llvm/IR/LLVMContext.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Transforms/Utils/SplitModule.h"
41 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
42 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
43 #include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
45 #include "mlir/InitAllDialects.h"  // from @llvm-project
46 #include "mlir/Pass/PassManager.h"  // from @llvm-project
47 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
48 #include "mlir/Transforms/LocationSnapshot.h"  // from @llvm-project
49 #include "mlir/Transforms/Passes.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/utils/name_utils.h"
51 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
52 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
53 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
54 #include "tensorflow/compiler/xla/protobuf_util.h"
55 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
56 #include "tensorflow/compiler/xla/service/all_gather_broadcast_reorder.h"
57 #include "tensorflow/compiler/xla/service/all_gather_combiner.h"
58 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
59 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
60 #include "tensorflow/compiler/xla/service/all_reduce_contiguous.h"
61 #include "tensorflow/compiler/xla/service/all_reduce_folder.h"
62 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
63 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
64 #include "tensorflow/compiler/xla/service/async_collective_creator.h"
65 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
66 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
67 #include "tensorflow/compiler/xla/service/bitcast_decomposer.h"
68 #include "tensorflow/compiler/xla/service/bitcast_dtypes_expander.h"
69 #include "tensorflow/compiler/xla/service/broadcast_canonicalizer.h"
70 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
71 #include "tensorflow/compiler/xla/service/call_inliner.h"
72 #include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
73 #include "tensorflow/compiler/xla/service/comparison_expander.h"
74 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
75 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
76 #include "tensorflow/compiler/xla/service/convert_mover.h"
77 #include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
78 #include "tensorflow/compiler/xla/service/convolution_pred_expander.h"
79 #include "tensorflow/compiler/xla/service/copy_insertion.h"
80 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
81 #include "tensorflow/compiler/xla/service/dot_merger.h"
82 #include "tensorflow/compiler/xla/service/dump.h"
83 #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
84 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
85 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
86 #include "tensorflow/compiler/xla/service/eigh_expander.h"
87 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
88 #include "tensorflow/compiler/xla/service/gather_expander.h"
89 #include "tensorflow/compiler/xla/service/gather_simplifier.h"
90 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
91 #include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"
92 #include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"
93 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
94 #include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
95 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
96 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
97 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
98 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
99 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
100 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h"
101 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
102 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
103 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
104 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
105 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
106 #include "tensorflow/compiler/xla/service/gpu/gpu_shape_verifier.h"
107 #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.h"
108 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
109 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
110 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
111 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
112 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
113 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
114 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
115 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
116 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
117 #include "tensorflow/compiler/xla/service/gpu/metrics.h"
118 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
119 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
120 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
121 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
122 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
123 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
124 #include "tensorflow/compiler/xla/service/gpu/runtime_intrinsics.h"
125 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
126 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
127 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
128 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
129 #include "tensorflow/compiler/xla/service/hlo_computation.h"
130 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
131 #include "tensorflow/compiler/xla/service/hlo_cse.h"
132 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
133 #include "tensorflow/compiler/xla/service/hlo_dce.h"
134 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
135 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
136 #include "tensorflow/compiler/xla/service/hlo_parser.h"
137 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
138 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
139 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
140 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
141 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
142 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
143 #include "tensorflow/compiler/xla/service/layout_normalization.h"
144 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
145 #include "tensorflow/compiler/xla/service/logistic_expander.h"
146 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
147 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
148 #include "tensorflow/compiler/xla/service/optimization_barrier_expander.h"
149 #include "tensorflow/compiler/xla/service/qr_expander.h"
150 #include "tensorflow/compiler/xla/service/real_imag_expander.h"
151 #include "tensorflow/compiler/xla/service/reduce_decomposer.h"
152 #include "tensorflow/compiler/xla/service/reduce_scatter_combiner.h"
153 #include "tensorflow/compiler/xla/service/reshape_decomposer.h"
154 #include "tensorflow/compiler/xla/service/reshape_mover.h"
155 #include "tensorflow/compiler/xla/service/result_caster.h"
156 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
157 #include "tensorflow/compiler/xla/service/rng_expander.h"
158 #include "tensorflow/compiler/xla/service/scatter_simplifier.h"
159 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
160 #include "tensorflow/compiler/xla/service/sharding_remover.h"
161 #include "tensorflow/compiler/xla/service/simplify_fp_conversions.h"
162 #include "tensorflow/compiler/xla/service/slice_sinker.h"
163 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
164 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
165 #include "tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.h"
166 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
167 #include "tensorflow/compiler/xla/service/transpose_folding.h"
168 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
169 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
170 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
171 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
172 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
173 #include "tensorflow/compiler/xla/status_macros.h"
174 #include "tensorflow/compiler/xla/types.h"
175 #include "tensorflow/compiler/xla/util.h"
176 #include "tensorflow/core/lib/core/status.h"
177 #include "tensorflow/core/platform/blocking_counter.h"
178 #include "tensorflow/core/platform/casts.h"
179 #include "tensorflow/core/platform/env.h"
180 #include "tensorflow/core/platform/logging.h"
181 #include "tensorflow/core/platform/regexp.h"
182 #include "tensorflow/core/platform/statusor.h"
183 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
184 #include "tensorflow/core/platform/threadpool.h"
185 #include "tensorflow/core/profiler/lib/traceme.h"
186 #include "tensorflow/core/util/env_var.h"
187 
188 #if XLA_ENABLE_XLIR
189 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pass_utils.h"
190 #include "tensorflow/compiler/xla/mlir/transforms/runtime/compilation_pipeline.h"
191 #include "tensorflow/compiler/xla/runtime/jit_executable.h"
192 #include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
193 #endif  // XLA_ENABLE_XLIR
194 
195 namespace xla {
196 namespace gpu {
197 namespace {
198 
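// Tells BFloat16Normalization which instructions may keep BF16 operands and
// outputs on GPU (see IsSupported below): collective ops, data-movement-only
// ops, bitcasts, BF16-capable convolutions, and (when enabled) matrix
// multiplications.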
199 class GpuBfloat16Support : public BFloat16Support {
200  public:
201   explicit GpuBfloat16Support(bool supports_matrix_multiplication,
202                               se::StreamExecutor* stream_exec)
203       : supports_matrix_multiplication_(supports_matrix_multiplication),
204         stream_exec_(stream_exec) {}
205 
206   bool SupportsBF16Operand(const HloInstruction& hlo,
207                            int64_t operand_index) const override {
208     return BFloat16Support::SupportsBF16Operand(hlo, operand_index) ||
209            IsSupported(hlo);
210   }
211 
212   // Returns whether the backend supports BF16 output for the HLO instruction.
213   bool SupportsBF16Output(const HloInstruction& hlo) const override {
214     return BFloat16Support::SupportsBF16Output(hlo) || IsSupported(hlo);
215   }
216 
217  private:
218   bool IsSupported(const HloInstruction& hlo) const {
219     switch (hlo.opcode()) {
220       // Collective ops.
221       case HloOpcode::kAllGather:
222       case HloOpcode::kAllReduce:
223       case HloOpcode::kAllReduceStart:
224       case HloOpcode::kAllReduceDone:
225       case HloOpcode::kAllToAll:
226       case HloOpcode::kCollectivePermute:
227       case HloOpcode::kReduceScatter:
228       // Data movement only ops.
229       case HloOpcode::kBroadcast:
230       case HloOpcode::kConcatenate:
231       case HloOpcode::kCopy:
232       case HloOpcode::kDynamicSlice:
233       case HloOpcode::kDynamicUpdateSlice:
234       case HloOpcode::kGather:
235       case HloOpcode::kPad:
236       case HloOpcode::kReshape:
237       case HloOpcode::kReverse:
238       case HloOpcode::kScatter:
239       case HloOpcode::kSelect:
240       case HloOpcode::kSelectAndScatter:
241       case HloOpcode::kSlice:
242       case HloOpcode::kTranspose:
243       // Other special ops.
244       case HloOpcode::kBitcast:
245         return true;
246       case HloOpcode::kConvolution:
247         return IsConvBF16Supported();
248       default:
249         return supports_matrix_multiplication_ &&
250                gpu::IsMatrixMultiplication(hlo);
251     }
252   }
253 
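  // BF16 convolutions are reported as supported only with cuDNN >= 8.2 on an
  // Ampere-or-newer GPU.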
254   bool IsConvBF16Supported() const {
255     if (se::dnn::DnnSupport* dnn = stream_exec_->AsDnn()) {
256       se::port::StatusOr<se::dnn::VersionInfo> cudnn_version =
257           dnn->GetVersion();
258       return cudnn_version.ok() &&
259              (cudnn_version->major_version() > 8 ||
260               (cudnn_version->major_version() == 8 &&
261                cudnn_version->minor_version() >= 2)) &&
262              stream_exec_->GetDeviceDescription()
263                  .cuda_compute_capability()
264                  .IsAtLeast(se::CudaComputeCapability::AMPERE);
265     }
266     return false;
267   }
268 
269   bool supports_matrix_multiplication_;
270   se::StreamExecutor* stream_exec_;
271 };
272 
273 int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
274   if (shape.is_static() || shape.IsTuple()) {
275     return ShapeUtil::ByteSizeOf(shape, pointer_size);
276   }
277   // Each dynamic dimension size is represented as an S32.
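  // Illustrative example: a dynamic shape such as f32[<=8] takes 8 * 4 = 32
  // bytes of data plus 1 * 4 = 4 bytes of dimension metadata, so this
  // function would report 36 bytes.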
278   int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size();
279   return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
280 }
281 
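// A convolution is considered lowerable if it matches one of the patterns the
// GPU conv rewriter handles: forward, backward-filter, or backward-input.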
282 bool ConvIsLowerable(HloInstruction* conv) {
283   return conv_matchers::CanImplementAsGpuForwardConv(conv) ||
284          std::get<0>(conv_matchers::MatchBackwardFilter(conv)) ||
285          std::get<0>(conv_matchers::MatchBackwardInput(conv));
286 }
287 
288 }  // end anonymous namespace
289 
290 using OwnedThunkSequence = GpuExecutable::OwnedThunkSequence;
291 using OwnedJitRtProgram = GpuExecutable::OwnedJitRtProgram;
292 
293 StatusOr<std::unique_ptr<Executable>> JitRtAotCompilationResult::LoadExecutable(
294     Compiler* compiler, se::StreamExecutor* executor) const {
295   TF_ASSIGN_OR_RETURN(
296       HloModuleConfig hlo_module_config,
297       HloModule::CreateModuleConfigFromProto(
298           jitrt_executable_.hlo_module_proto(), GetDebugOptionsFromFlags()));
299   TF_ASSIGN_OR_RETURN(
300       std::unique_ptr<HloModule> hlo_module,
301       HloModule::CreateFromProto(jitrt_executable_.hlo_module_proto(),
302                                  hlo_module_config));
303   auto gpu_compiler = tensorflow::down_cast<GpuCompiler*>(compiler);
304   return GpuExecutable::LoadFromObjFile(
305       std::move(hlo_module), jitrt_executable_.obj_file(),
306       jitrt_executable_.mlir_module(), jitrt_executable_.entry_func_attrs(),
307       GetDebugOptionsFromFlags(), gpu_compiler->GetGpuVersion(executor),
308       executor);
309 }
310 
311 GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
312                          const char* target_triple, const char* data_layout)
313     : platform_id_(platform_id),
314       target_triple_(target_triple),
315       data_layout_(data_layout),
316       pointer_size_(llvm::DataLayout(data_layout)
317                         .getPointerSize(0 /* default address space */)) {}
318 
319 namespace {
320 // Adds the HloVerifier for GPU to the given pipeline.
321 void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
322                     bool debug_only = false) {
323   std::unique_ptr<TargetVerifierMetadata> verifier_metadata =
324       std::make_unique<GpuVerifierMetadata>(std::move(opts));
325   if (debug_only) {
326     pipeline->AddInvariantCheckerDebug<HloVerifier>(
327         std::move(verifier_metadata), "hlo verifier (debug)");
328   } else {
329     pipeline->AddInvariantChecker<HloVerifier>(std::move(verifier_metadata),
330                                                "hlo verifier");
331   }
332 }
333 }  // namespace
334 
335 // Runs optimization passes on the given HLO module.
336 Status GpuCompiler::OptimizeHloModule(
337     HloModule* hlo_module, se::StreamExecutor* stream_exec,
338     se::DeviceMemoryAllocator* device_allocator) {
339   const DebugOptions& debug_options = hlo_module->config().debug_options();
340 
341   AlgebraicSimplifierOptions layout_insensitive_algsimp_opts({},
342                                                              ConvIsLowerable);
343   // "slow" minmax means we propagate nan.
344   layout_insensitive_algsimp_opts.set_minmax_propagate_nan(
345       !debug_options.xla_gpu_enable_fast_min_max());
346 
347   const se::Platform* platform = stream_exec->platform();
348   if (platform->Name() == "ROCM") {
349     // SwapConvOperands does not yet work on ROCM
350     layout_insensitive_algsimp_opts.set_enable_conv_operand_swap(false);
351   }
352 
353   if (hlo_module->config().use_spmd_partitioning()) {
354     HloPassPipeline spmd_pipeline("spmd-partitioner");
355     AddHloVerifier(&spmd_pipeline);
356     const int64_t num_partitions = hlo_module->config().num_partitions();
357     if (num_partitions > 1) {
358       // Run some IR cleanup passes before running the SPMD partitioning
359       // passes.
360       spmd_pipeline.AddPass<CallInliner>();
361       spmd_pipeline.AddPass<ZeroSizedHloElimination>();
362       spmd_pipeline.AddPass<ConditionalCanonicalizer>();
363 
364       HloPassPipeline& spmd_simplify =
365           spmd_pipeline.AddPass<HloPassFix<HloPassPipeline>>("spmd-simplify");
366 
367       spmd_simplify.AddPass<AlgebraicSimplifier>(
368           layout_insensitive_algsimp_opts);
369 
370       spmd_simplify.AddPass<SortSimplifier>();
371       spmd_simplify.AddPass<TupleSimplifier>();
372       spmd_simplify.AddPass<ScatterSimplifier>();
373       spmd_simplify.AddPass<ScatterExpander>(
374           ScatterExpander::kEliminateSimpleScatters);
375       spmd_simplify.AddPass<GatherSimplifier>();
376       spmd_simplify.AddPass<GatherExpander>(
377           GatherExpander::kEliminateSimpleGathers);
378       spmd_simplify.AddPass<WhileLoopConstantSinking>();
379       spmd_simplify.AddPass<WhileLoopSimplifier>();
380 
381       spmd_simplify.AddPass<ReshapeMover>();
382       spmd_simplify.AddPass<HloConstantFolding>();
383       spmd_simplify.AddPass<ConditionalSimplifier>();
384       spmd_simplify.AddPass<HloDCE>();
385 
386       spmd_pipeline.AddPass<ShardingPropagation>(
387           /*is_spmd=*/true, /*propagate_metadata=*/false,
388           hlo_module->config().allow_spmd_sharding_propagation_to_output());
389       spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(
390           num_partitions, hlo_module->config().replica_count());
391     } else {
392       // Remove redundant sharding ops when partition_count == 1.
393       spmd_pipeline.AddPass<ShardingRemover>();
394       spmd_pipeline.AddPass<HloDCE>();
395     }
396     TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
397   }
398 
399   {
400     HloPassPipeline pipeline("optimization");
401     AddHloVerifier(&pipeline);
402     pipeline.AddPass<AllToAllDecomposer>();
403 
404     HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
405       return !stream_exec->GetDeviceDescription()
406                   .cuda_compute_capability()
407                   .IsAtLeast(se::CudaComputeCapability::VOLTA) ||
408              !gpu::IsMatrixMultiplication(*instr);
409     };
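    // In other words, operand/result upcasting is skipped for matrix
    // multiplications on Volta-or-newer GPUs.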
410 
411     pipeline.AddPass<OperandUpcaster>(upcaster_filter);
412     pipeline.AddPass<ResultCaster>(upcaster_filter);
413 
414     // Expand random number generation.
415     pipeline.AddPass<RngExpander>();
416     pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
417 
418     // Comparison total order expander
419     pipeline.AddPass<ComparisonExpander>();
420 
421     // Remove zero-sized HLO from the input so that other passes don't have to
422     // handle it.
423     pipeline.AddPass<ZeroSizedHloElimination>();
424 
425     if (debug_options.xla_gpu_deterministic_ops()) {
426       // Scatter is nondeterministic, so eliminate all Scatters.
427       pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
428     } else {
429       // Only Scatters unsupported on XLA:GPU are eliminated.
430       pipeline.AddPass<GpuScatterExpander>();
431     }
432     // TODO(phawkins): replace QR and Eigh decompositions with calls to
433     // cuSOLVER.
434     pipeline.AddPass<QrExpander>();
435     pipeline.AddPass<EighExpander>();
436 
437     pipeline.AddPass<DynamicIndexSplitter>();
438 
439     // TODO(b/64094172): make Call work on GPU instead of inlining.
440     pipeline.AddPass<CallInliner>();
441 
442     pipeline.AddPass<DotDecomposer>();
443 
444     pipeline.AddPass<Convolution4DExpander>();
445 
446     // Replace PRED convolutions with F16.
447     pipeline.AddPass<ConvolutionPredExpander>();
448 
449     // Expand the sort op to support stable sorting if required.
450     pipeline.AddPass<StableSortExpander>();
451 
452     GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/true,
453                             stream_exec);
454     pipeline.AddPass<BFloat16Normalization>(&bf16);
455 
456     // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
457     if (debug_options.xla_gpu_simplify_all_fp_conversions())
458       pipeline.AddPass<SimplifyFPConversions>();
459 
460     pipeline.AddPass<BatchNormExpander>(
461         /*rewrite_training_op=*/true,
462         /*rewrite_inference_op=*/true,
463         /*rewrite_grad_op=*/true);
464 
465     pipeline.AddPass<LogisticExpander>(
466         /*expansion_type=*/LogisticExpansionType::kExp);
467     pipeline.AddPass<ConditionalCanonicalizer>();
468     pipeline.AddPass<DynamicDimensionSimplifier>();
469 
470     DynamicPadderOptions dynamic_padder_options;
471 
472     switch (hlo_module->config().debug_options().xla_gpu_shape_checks()) {
473       case DebugOptions::IGNORE:
474         dynamic_padder_options.shape_check_mode =
475             DynamicDimensionInference::ShapeCheckMode::kIgnore;
476         break;
477       case DebugOptions::RUNTIME: {
478         dynamic_padder_options.shape_check_mode =
479             DynamicDimensionInference::ShapeCheckMode::kRuntime;
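        // In RUNTIME mode the padder emits a side-effecting assert custom
        // call (kXlaGpuAssertCustomCallTag) on the generated shape check,
        // which is expected to fail at run time with the message below when
        // the check is violated.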
480         dynamic_padder_options.assertion_generator = [&](HloInstruction* inst) {
481           auto created = Cast<HloCustomCallInstruction>(
482               inst->parent()->AddInstruction(HloInstruction::CreateCustomCall(
483                   ShapeUtil::MakeTokenShape(), {inst},
484                   kXlaGpuAssertCustomCallTag,
485                   "Buffers have different size at runtime",
486                   API_VERSION_STATUS_RETURNING)));
487           created->set_custom_call_has_side_effect(true);
488         };
489         break;
490       }
491       case DebugOptions::COMPILE_TIME:
492         dynamic_padder_options.shape_check_mode =
493             DynamicDimensionInference::ShapeCheckMode::kCompileTime;
494         break;
495       default:
496         LOG(FATAL) << "Unreachable";
497     }
498 
499     pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
500 
501     // Build simplification pipeline.  The passes in here are run to a fixed
502     // point.
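    // Implementation note: the immediately-invoked lambda below uses an
    // init-capture (`&pipeline = pipeline.AddPass<...>`) so that, inside the
    // lambda body, `pipeline` refers to the nested fixed-point sub-pipeline
    // rather than the enclosing "optimization" pipeline.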
503     [&, &pipeline =
504             pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
505       AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true);
506 
507       // BatchNormExpander can create zero-sized ops, so zero-sized HLO
508       // elimination has to come after that pass.
509       pipeline.AddPass<ZeroSizedHloElimination>();
510 
511       pipeline.AddPass<GatherSimplifier>();
512       pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
513       pipeline.AddPass<ScatterSimplifier>();
514       pipeline.AddPass<ScatterExpander>(
515           ScatterExpander::kEliminateSimpleScatters);
516       pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
517       pipeline.AddPass<BitcastDtypesExpander>();
518       // AlgebraicSimplifier may add contracting dimensions to a dot.
519       pipeline.AddPass<DotDecomposer>();
520       // Only merge "smallish" dots.  This threshold was not set carefully, but
521       // so far we know that 1mb is too small.
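      // (int64_t{16} << 20 below is 16 MiB.)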
522       pipeline.AddPass<DotMerger>(/*max_size_to_merge=*/int64_t{16} << 20);
523       pipeline.AddPass<SortSimplifier>();
524       pipeline.AddPass<TupleSimplifier>();
525       pipeline.AddPass<WhileLoopConstantSinking>();
526       pipeline.AddPass<WhileLoopSimplifier>();
527 
528       // TODO(b/134075051): Re-enable after b/134075051 is fixed.
529       // pipeline.AddPass<SliceSinker>();
530 
531       pipeline.AddPass<ReshapeMover>();
532       pipeline.AddPass<HloConstantFolding>();
533       pipeline.AddPass<ConditionalSimplifier>();
534       pipeline.AddPass<RealImagExpander>();
535       pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot);
536       pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
537       pipeline.AddPass<HloDCE>();
538     }();
539 
540     // ConvertMover and ReshapeMover fight with each other: ConvertMover wants
541     // to move some converts down the graph, but ReshapeMover wants to move them
542     // up the graph.  As a compromise, let ReshapeMover run to a fixed point,
543     // and then run ConvertMover + algsimp to a fixed point.
544     [&, &pipeline =
545             pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification-2")] {
546       pipeline.AddPass<ConvertMover>();
547       pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
548     }();
549 
550     // Run WhileLoopTripCountAnnotator at the end of the simplification
551     // pipeline, before layout assignment and fusion.  This pass does some
552     // pattern-matching on while bodies/conditions, and this is where the HLO is
553     // "nicest".
554     //
555     // It's important that we don't make semantic changes (e.g. unrolling) to
556     // any `while` loops after this point, because otherwise the trip-count
557     // annotations added by this pass may not be correct after the
558     // modifications.
559     pipeline.AddPass<WhileLoopTripCountAnnotator>();
560     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
561   }
562 
563   // Optimize collectives generated by SPMD partitioning. These passes also
564   // run without SPMD so that all collectives get these optimizations.
565   {
566     HloPassPipeline collectives_pipeline("collective-optimizations");
567     collectives_pipeline.AddPass<AllReduceFolder>();
568     collectives_pipeline.AddPass<ReduceScatterCreator>();
569     collectives_pipeline.AddPass<AllReduceReassociate>();
570 
571     // Run algebraic simplifier to reshape(broadcast) into a broadcast when
572     // the reshape is just adding a unit dimension. This will help with the
573     // AllGatherBroadcastReorder pass.
574     collectives_pipeline.AddPass<AlgebraicSimplifier>(
575         layout_insensitive_algsimp_opts);
576 
577     collectives_pipeline.AddPass<AllGatherBroadcastReorder>();
578     TF_RETURN_IF_ERROR(collectives_pipeline.Run(hlo_module).status());
579   }
580 
581   // Run target-specific HLO optimization passes for convolution
582   // canonicalization.
583   TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
584       hlo_module, stream_exec, device_allocator));
585 
586   {
587     // Run layout assignment in a separate pipeline from
588     // "post-layout-assignment" because we want everything after layout
589     // assignment to have a layout-sensitive invariant-checker, but
590     // HloPassPipeline also runs its invariant checker before any passes are
591     // run, meaning, the pipeline that contains layout assignment cannot contain
592     // a layout-sensitive verifier!
593     HloPassPipeline pipeline("layout assignment");
594     // Layout assignment uses alias analysis, which requires the call graph to
595     // be flattened.
596     pipeline.AddPass<FlattenCallGraph>();
597     ChannelLayoutConstraints layout_constraints;
598     pipeline.AddPass<GpuLayoutAssignment>(
599         hlo_module->mutable_entry_computation_layout(), stream_exec,
600         &layout_constraints);
601     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
602   }
603 
604   // Run target-specific HLO optimization passes after layout assignment.
605   TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
606                                                      device_allocator));
607 
608   {
609     HloPassFix<HloPassPipeline> fusion("fusion");
610     // We try to split variadic ops with many parameters into several such ops
611     // to avoid exceeding the parameter space.
612     fusion.AddPass<VariadicOpSplitter>();
613     AddHloVerifier(
614         &fusion,
615         HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
616             LayoutAssignment::InstructionCanChangeLayout),
617         /*debug_only=*/true);
618     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
619     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
620     fusion.AddPass<FusionMerger>();
621     fusion.AddPass<GpuMultiOutputFusion>();
622     fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
623                            /*only_fusion_computations=*/true);
624     fusion.AddPass<HloDCE>();
625     TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
626   }
627 
628   {
629     HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
630     horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
631     horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
632     // FusionBitcastLift must be after InstructionFusion, as it undoes
633     // part of it.
634     // TODO(b/209005695): Re-enable once the bug is fixed.
635     // horizontal_fusion.AddPass<FusionBitcastLift>();
636     horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
637                                       /*only_fusion_computations=*/true);
638     horizontal_fusion.AddPass<HloDCE>();
639     TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
640   }
641 
642   if (VLOG_IS_ON(2)) {
643     HloFusionStatsVisitor stats;
644     TF_RETURN_IF_ERROR(hlo_module->entry_computation()->Accept(&stats));
645     VLOG(2) << stats.ToString();
646   }
647 
648   {
649     HloPassPipeline pipeline("post-fusion optimization");
650     pipeline.AddPass<AllGatherCombiner>(
651         /*combine_threshold_in_bytes=*/1024 * 1024 * 1024,
652         /*combine_threshold_count=*/256);
653     pipeline.AddPass<AllReduceCombiner>(
654         debug_options.xla_gpu_all_reduce_combine_threshold_bytes(),
655         /*combine_threshold_count=*/256);
656     pipeline.AddPass<ReduceScatterCombiner>(
657         /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
658         /*combine_threshold_count=*/256);
659 
660     if (debug_options.xla_gpu_all_reduce_contiguous()) {
661       pipeline.AddPass<AllReduceContiguous>();
662     }
663 
664     int32_t blueconnect_num_devices_per_host =
665         debug_options.xla_gpu_all_reduce_blueconnect_num_devices_per_host();
666     if (blueconnect_num_devices_per_host > 0) {
667       pipeline.AddPass<AllReduceBlueConnect>(blueconnect_num_devices_per_host);
668     }
669 
670     if (debug_options.xla_gpu_enable_async_all_reduce()) {
671       AsyncCollectiveCreator::CollectiveCreatorConfig config;
672       config.convert_all_reduce = [](const HloInstruction*) { return true; };
673       pipeline.AddPass<AsyncCollectiveCreator>(std::move(config));
674     }
675 
676     pipeline.AddPass<CollectivesScheduleLinearizer>();
677 
678     AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts;
679     options.set_is_layout_sensitive(true);
680     pipeline.AddPass<AlgebraicSimplifier>(options);
681     pipeline.AddPass<OptimizationBarrierExpander>();
682     pipeline.AddPass<BitcastDecomposer>();
683     pipeline.AddPass<TupleSimplifier>();
684 
685     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
686   }
687 
688   return OkStatus();
689 }
690 
691 // Modifies the given HLO module so that it will be accepted by IrEmitter.
692 // Unlike optimization passes, these passes are necessary for correctness.
693 Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
694   // In some cases, we have to place the result of an instruction in a temporary
695   // buffer. For instance, the buffer that holds an external parameter is
696   // assumed immutable at this point, and should not be reused for output
697   // (b/27180329). Therefore, in that case, we set the output to be a copy of
698   // the parameter.
699   HloPassPipeline pipeline("GPU-ir-emit-prepare");
700   AddHloVerifier(
701       &pipeline,
702       HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
703           LayoutAssignment::InstructionCanChangeLayout),
704       /*debug_only=*/true);
705 
706   // Copy insertion should be performed immediately before IR emission to avoid
707   // inserting unnecessary copies (later pass adds an instruction which
708   // materializes the value) or missing a necessary copy (later pass removes an
709   // instruction which materializes a value). DCE must be run immediately before
710   // (and sometime after) copy insertion, to avoid dead code from interfering
711   // with the rewrites.
712   pipeline.AddPass<HloDCE>();
713   if (hlo_module->config().alias_passthrough_params()) {
714     pipeline.AddPass<AliasPassthroughParams>();
715   }
716   pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
717   pipeline.AddPass<CopyInsertion>(GetCanShareBuffer());
718   pipeline.AddPass<GpuSanitizeConstantNames>();
719   return pipeline.Run(hlo_module).status();
720 }
721 
722 Status GpuCompiler::OptimizeHloPostLayoutAssignment(
723     HloModule* hlo_module, se::StreamExecutor* stream_exec,
724     se::DeviceMemoryAllocator* device_allocator) {
725   const DebugOptions& debug_options = hlo_module->config().debug_options();
726 
727   {
728     HloPassPipeline pipeline("hlo normalization");
729     pipeline.AddPass<ReshapeDecomposer>();
730     pipeline.AddPass<ReduceDecomposer>([&](const HloInstruction* r) {
731       return IsReductionFromOrToContiguousDimensions(*r);
732     });
733     if (hlo_module->config().debug_options().xla_gpu_normalize_layouts()) {
734       pipeline.AddPass<LayoutNormalization>();
735     }
736     pipeline.AddPass<BroadcastCanonicalizer>();
737     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
738   }
739 
740   HloPassPipeline pipeline("post-layout_assignment");
741   AddHloVerifier(&pipeline,
742                  HloVerifierOpts{}
743                      .MakeLayoutSensitive()
744                      .WithInstructionCanChangeLayout(
745                          LayoutAssignment::InstructionCanChangeLayout)
746                      .VerifyBroadcastDimensionsOrder()
747                      .VerifyReshapeIsBitcast(),
748                  /*debug_only=*/true);
749 
750   pipeline.AddPass<ReductionDegenerateDimRemover>();
751   pipeline.AddPass<ReductionLayoutNormalizer>();
752   pipeline.AddPass<ReductionDimensionGrouper>();
753   pipeline.AddPass<HloPassFix<ReductionSplitter>>();
754   pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(
755       stream_exec->GetDeviceDescription().cuda_compute_capability());
756 
757   // The LayoutAssignment pass may leave behind kCopy instructions which are
758   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
759   AlgebraicSimplifierOptions options;
760   options.set_is_layout_sensitive(true);
761   options.set_enable_conv_operand_swap(false);
762   // "slow" minmax means we propagate nan.
763   options.set_minmax_propagate_nan(
764       !hlo_module->config().debug_options().xla_gpu_enable_fast_min_max());
765   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
766 
767   // GemmRewriter assumes that all transposes are folded into gemms, but,
768   // since commit 7d529df, this is not always true at this point.
769   // Therefore, rerun transpose folding.
770   pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot,
771                                      TransposeFolding::NeverFoldTranspose);
772   // Rewrite GEMMs into custom calls.
773   pipeline.AddPass<GemmRewriter>();
774 
775   // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
776   pipeline.AddPass<GemmBroadcastFoldingRewriter>();
777 
778   // Run conversion again, to catch those matrix multiplications which were not
779   // rewritten into cuBLAS calls.
780   GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/false,
781                           stream_exec);
782   pipeline.AddPass<BFloat16Normalization>(&bf16);
783 
784   // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
785   if (debug_options.xla_gpu_simplify_all_fp_conversions())
786     pipeline.AddPass<SimplifyFPConversions>();
787 
788   // Choose the fastest algorithm for each conv.
789   //
790   // We pick the algorithm before fusion so we can generate better HLO. After
791   // GpuConvRewriter, our convolutions are CustomCalls which return a
792   // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
793   // scratch:
794   //
795   //   customcall = (f32[...], f32[0])
796   //   return gte(customcall, 0)
797   //
798   // The algorithm picker then chooses the best algorithm, and potentially
799   // increases the scratch space.  It replaces customcall with new_tuple,
800   // giving us the following:
801   //
802   //   new_customcall = (f32[...], f32[N])
803   //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
804   //   return gte(new_tuple, 0)
805   //
806   // The new tuple and gte instructions can then be simplified away, because
807   // nobody is expected to use the scratch value.
808   //
809   // However, if we were to run GpuConvAlgorithmPicker after fusion
810   // the gte(customcall, 0) would probably already be inside a fusion node.  We
811   // can't simplify across HloComputation boundaries, so in this case we
812   // wouldn't be able to simplify away the new_tuple bits.
813   pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
814 
815   // Clean up new_tuple described above.
816   pipeline.AddPass<TupleSimplifier>();
817 
818   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
819   TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
820 
821   return OkStatus();
822 }
823 
824 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
825     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
826     const CompileOptions& options) {
827   // We dump the post-optimization HLO in RunBackend so no need to dump it here.
828   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
829   uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
830   tensorflow::profiler::TraceMe activity(
831       [&] { return absl::StrCat("HLO Transforms:", module->name()); },
832       tensorflow::profiler::TraceMeLevel::kInfo);
833   TF_RETURN_IF_ERROR(
834       OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
835 
836   TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
837 
838   uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
839 
840   // This won't record values for calls that error out (because if they error
841   // out we have no way of telling how far through the process we got).
842   RecordHloPassesDuration(end_usecs - start_usecs);
843 
844   return std::move(module);
845 }
846 
847 static std::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
848                                                        const HloInstruction*,
849                                                        const ShapeIndex&) {
850   return std::nullopt;
851 }
852 
853 StatusOr<std::unique_ptr<BufferAssignment>> GpuCompiler::AssignBuffers(
854     const HloModule* hlo_module) {
855   TF_ASSIGN_OR_RETURN(HloSchedule hlo_schedule,
856                       ScheduleGpuModule(hlo_module, pointer_size_));
857 
858   auto buffer_size_bytes_function =
859       [this](const BufferValue& buffer_value) -> int64_t {
860     return GetSizeOfShape(buffer_value.shape(), pointer_size_);
861   };
862 
863   TF_ASSIGN_OR_RETURN(
864       std::unique_ptr<BufferAssignment> assignment,
865       BufferAssigner::Run(
866           hlo_module, std::make_unique<SequentialHloOrdering>(hlo_schedule),
867           buffer_size_bytes_function,
868           /*color_alignment=*/
869           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
870           /*allocate_buffers_for_constants=*/true,
871           /*colorer=*/BufferAssigner::DefaultColorer(),
872           /*must_not_live_out=*/{}, GetCanShareBuffer()));
873 
874   return std::move(assignment);
875 }
876 
877 #if XLA_ENABLE_XLIR
878 static StatusOr<OwnedJitRtProgram> LowerToJitRt(
879     mlir::ModuleOp mlir_module, llvm::StringRef entry_function_name,
880     llvm::ArrayRef<int64_t> buffer_sizes, HloModule* hlo_module,
881     se::StreamExecutor* stream_exec) {
882   // Forward collective (NCCL) attributes for use by the lowering pipeline.
883   mlir::OpBuilder builder(mlir_module.getContext());
884   mlir::IntegerAttr replica_count_attr =
885       builder.getI64IntegerAttr(hlo_module->config().replica_count());
886   mlir::IntegerAttr num_partitions_attr =
887       builder.getI64IntegerAttr(hlo_module->config().num_partitions());
888   mlir::func::FuncOp func =
889       mlir_module.lookupSymbol<mlir::func::FuncOp>(entry_function_name);
890   func->setAttr("replica_count", replica_count_attr);
891   func->setAttr("num_partitions", num_partitions_attr);
892 
893   tensorflow::GpuBinaryOptions options;
894   if (stream_exec == nullptr) {
895     options = tensorflow::GpuBinaryOptions::DefaultGpuBinaryOptions();
896   } else {
897     options.platform_name = stream_exec->platform()->Name();
898     options.gpu_device_info = xla::gpu::GetGpuDeviceInfo(stream_exec);
899     options.cuda_compute_capability =
900         stream_exec->GetDeviceDescription().cuda_compute_capability();
901     options.rocm_compute_capability =
902         stream_exec->GetDeviceDescription().rocm_compute_capability();
903   }
904 
905   // Lower LMHLO operations to the JitRt compatible custom calls.
906   TF_RETURN_IF_ERROR(tensorflow::ConvertLmhloToJitRt(
907       mlir_module, {entry_function_name.data(), entry_function_name.size()},
908       buffer_sizes, options));
909   // Serialize module to pass it to GpuExecutable for compilation.
910   std::string serialized_module;
911   llvm::raw_string_ostream os(serialized_module);
912   mlir_module.print(os);
913 
914   // TODO(b/232033540): Pass MLIR module directly to JitRt to instantiate an
915   // executable, without forcing serialization.
916   return std::make_unique<GpuExecutable::JitRtProgram>(
917       entry_function_name.str(), os.str(), buffer_sizes.vec(),
918       hlo_module->config().debug_options());
919 }
920 #endif  // XLA_ENABLE_XLIR
921 
922 using OutputInfoMap =
923     absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
924 static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
925                                     std::vector<BufferAllocation>* allocations,
926                                     OutputInfoMap* output_info,
927                                     Shape* output_shape,
928                                     EntryFunctionAttributes* entry_func_attrs);
929 
930 namespace {
931 // Removes all globals from the given module that are both uninitialized and
932 // have no uses within that module.
933 void RemoveUnusedAndUninitializedGlobals(
934     llvm::Module* llvm_module,
935     const std::vector<GpuExecutable::ConstantInfo>& constants) {
936   for (const auto& info : constants) {
937     // Empty content means the constant is initialized in the LLVM IR, so we
938     // must not remove it.
939     if (!info.content.empty()) {
940       llvm::GlobalVariable* global =
941           llvm_module->getGlobalVariable(info.symbol_name);
942       CHECK(global != nullptr);
943       if (global->use_empty()) {
944         global->eraseFromParent();
945       }
946     }
947   }
948 }
949 }  // namespace
950 
951 struct CompileModuleResults {
952   std::unique_ptr<llvm::Module> llvm_module;
953   std::unique_ptr<BufferAssignment> buffer_assignment;
954   std::vector<BufferAllocation> allocations;
955   std::variant<OwnedThunkSequence, OwnedJitRtProgram> executable;
956   EntryFunctionAttributes entry_func_attrs;
957   std::vector<GpuExecutable::ConstantInfo> constants;
958   OutputInfoMap output_info;
959   Shape output_shape;
960   std::string module_name;
961 };
962 
963 // The order of `thunk_sequence` corresponds to
964 // `hlo_schedule->ThunkLaunchOrder()`.
965 static Status CompileModuleToLlvmIrImpl(
966     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
967     const std::string& target_triple, const std::string& data_layout,
968     const std::string& platform_name, const se::Platform::Id platform_id,
969     GpuDeviceInfo gpu_device_info,
970     se::CudaComputeCapability cuda_compute_capability,
971     se::RocmComputeCapability rocm_compute_capability,
972     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
973     int pointer_size, CompileModuleResults* results,
974     se::StreamExecutor* stream_exec = nullptr) {
975   results->llvm_module = std::make_unique<llvm::Module>("", *llvm_context);
976   results->llvm_module->setTargetTriple(target_triple);
977   results->llvm_module->setDataLayout(data_layout);
978 
979   TF_ASSIGN_OR_RETURN(HloSchedule hlo_schedule,
980                       ScheduleGpuModule(hlo_module, pointer_size));
981 
982   auto buffer_size_bytes_function =
983       [pointer_size](const BufferValue& buffer_value) -> int64_t {
984     return GetSizeOfShape(buffer_value.shape(), pointer_size);
985   };
986 
987   TF_ASSIGN_OR_RETURN(
988       results->buffer_assignment,
989       BufferAssigner::Run(
990           hlo_module, std::make_unique<SequentialHloOrdering>(hlo_schedule),
991           buffer_size_bytes_function,
992           /*color_alignment=*/
993           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
994           /*allocate_buffers_for_constants=*/true,
995           /*colorer=*/BufferAssigner::DefaultColorer(),
996           /*must_not_live_out=*/{}, can_share_buffer_function));
997 
998   VLOG(1) << "Buffer Assignment Stats for " << hlo_module->name() << "\n"
999           << results->buffer_assignment->GetStats().ToString();
1000   DumpHloModuleIfEnabled(*hlo_module, *results->buffer_assignment,
1001                          absl::StrCat("sm_", cuda_compute_capability.ToString(),
1002                                       "_gpu_after_optimizations"));
1003 
1004   uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
1005   mlir::DialectRegistry registry;
1006   IrEmitterUnnested::GetDependentDialects(registry);
1007   mlir::MLIRContext mlir_context(registry);
1008   mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
1009       mlir::ModuleOp::create(mlir::Builder(&mlir_context).getUnknownLoc());
1010 
1011   TF_RETURN_IF_ERROR(
1012       HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module));
1013 
1014   results->module_name = mlir::GetNameFromLoc(mlir_module->getLoc());
1015 
1016   if (DumpingEnabledForHloModule(*hlo_module)) {
1017     DumpToFileInDirOrStdout(*hlo_module, "lmhlo", mlir_module.get());
1018   }
1019 
1020   auto entry_function = mlir::cast<mlir::func::FuncOp>(
1021       mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
1022 
1023   TF_RETURN_IF_ERROR(GetMlirAllocationInfo(
1024       entry_function, &results->allocations, &results->output_info,
1025       &results->output_shape, &results->entry_func_attrs));
1026 
1027   if (hlo_module->config().debug_options().xla_gpu_enable_mlir_lowering()) {
1028     mlir::PassManager pm(&mlir_context);
1029     pm.addPass(mlir::createGpuFusionRewritePass());
1030     if (failed(pm.run(mlir_module.get()))) {
1031       return InternalError("Failed to run gpu-fusion-rewrite pass");
1032     }
1033   }
1034 
1035   IrEmitterContext ir_emitter_context(
1036       /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr, platform_name,
1037       gpu_device_info, cuda_compute_capability, rocm_compute_capability,
1038       &mlir_context, results->llvm_module.get());
1039 
1040   ir_emitter_context.set_allocations(results->allocations);
1041 
1042   TF_ASSIGN_OR_RETURN(
1043       auto ir_emitter,
1044       IrEmitterUnnested::Create(hlo_module->config(), &ir_emitter_context));
1045 
1046   {
1047     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
1048 
1049     TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.getBody()));
1050 
1051     bool supports_runtime_managed_constants =
1052         // TODO(b/218907125): Implement this feature for ROCm as well.
1053         platform_id != se::rocm::kROCmPlatformId &&
1054         hlo_module->config().debug_options().xla_gpu_enable_shared_constants();
1055     if (supports_runtime_managed_constants) {
1056       // Remove these globals from the generated code to indicate that XLA is
1057       // responsible for allocating and initializing them.
1058       RemoveUnusedAndUninitializedGlobals(ir_emitter_context.llvm_module(),
1059                                           ir_emitter_context.constants());
1060     }
1061 
1062     results->constants = std::move(ir_emitter_context.constants());
1063     uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
1064 
1065     // This won't record values for calls that error out (because if they error
1066     // out we have no way of telling how far through the process we got).
1067     RecordHloToLlvmDuration(end_usecs - start_usecs);
1068   }
1069 
1070 #if XLA_ENABLE_XLIR
1071   if (IsJitRtExecutableEnabled(hlo_module->config())) {
1072     std::vector<int64_t> buffer_sizes;
1073     llvm::transform(
1074         results->allocations, std::back_inserter(buffer_sizes),
1075         [](const BufferAllocation& allocation) { return allocation.size(); });
1076     TF_ASSIGN_OR_RETURN(results->executable,
1077                         LowerToJitRt(*mlir_module, entry_function.getName(),
1078                                      buffer_sizes, hlo_module, stream_exec));
1079     return OkStatus();
1080   }
1081 #endif  // XLA_ENABLE_XLIR
1082 
1083   results->executable = ir_emitter->ConsumeThunkSequence();
1084   return OkStatus();
1085 }
1086 
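// Swallows LLVM backend diagnostics: they are formatted and logged only at
// VLOG level 5 instead of being surfaced to the user.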
1087 static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
1088                                   void* context) {
1089   std::string error_string;
1090   llvm::raw_string_ostream string_printer(error_string);
1091   llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer);
1092   diag_info.print(diagnostic_printer);
1093 
1094   VLOG(5) << error_string;
1095 }
1096 
1097 StatusOr<std::pair<std::string, std::vector<uint8_t>>>
1098 GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
1099                                    std::unique_ptr<llvm::Module> llvm_module,
1100                                    se::StreamExecutor* stream_exec,
1101                                    const CompileOptions& options,
1102                                    const HloModule* debug_module) {
1103   using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1104 
1105   const auto compile_single_module =
1106       [this, stream_exec, &module_config, debug_module](
1107           llvm::Module* llvm_module, bool relocatable,
1108           std::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
1109     {
1110       XLA_SCOPED_LOGGING_TIMER(
1111           "GpuCompiler::RunBackend - Running LLVM verifier");
1112 
1113       llvm_module->getContext().setDiagnosticHandlerCallBack(
1114           NullDiagnosticHandler, nullptr);
1115 
1116       std::string err;
1117       llvm::raw_string_ostream err_stream(err);
1118 
1119       // verifyModule() returns true if the module is broken.
1120       TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
1121           << "Invalid LLVM IR before optimizations:\n"
1122           << err_stream.str()
1123           << "\nThis probably indicates a bug in the HLO -> LLVM IR "
1124              "lowering. Rerun with --xla_dump_to to get the IR"
1125           << (debug_module
1126                   ? absl::StrCat(" and look for files with names containing: *",
1127                                  FilenameFor(*debug_module, "", ""), "*")
1128                   : ".");
1129     }
1130     GpuVersion gpu_version = GetGpuVersion(stream_exec);
1131     StatusOr<std::pair<std::string, std::vector<uint8_t>>> result =
1132         CompileTargetBinary(module_config, llvm_module, gpu_version,
1133                             stream_exec, relocatable, debug_module);
1134 
1135     if (!result.ok()) {
1136       return result;
1137     }
1138 
1139     const bool should_dump =
1140         DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
1141                                    module_config.debug_options());
1142 
1143     if (should_dump) {
1144       if (debug_module) {
1145         if (shard_number.has_value()) {
1146           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
1147                                    /*optimized=*/true,
1148                                    std::to_string(*shard_number));
1149         } else {
1150           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
1151                                    /*optimized=*/true);
1152         }
1153       } else {
1154         LOG(ERROR)
1155             << "Dumping is not implemented since the file name cannot be "
1156                "inferred. Please implement a (potentially MLIR) module -> "
1157                "filename heuristic.";
1158       }
1159     }
1160 
1161     if (user_post_optimization_hook_) {
1162       user_post_optimization_hook_(*llvm_module);
1163     }
1164 
1165     // Write the PTX to the IR dump directory if IR dumping was requested.
1166     if (should_dump) {
1167       absl::string_view ptx = result->first;
1168       if (debug_module) {
1169         if (shard_number.has_value()) {
1170           DumpToFileInDirOrStdout(*debug_module, "",
1171                                   std::to_string(*shard_number) + ".ptx", ptx);
1172         } else {
1173           DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
1174         }
1175       } else {
1176         LOG(ERROR)
1177             << "Dumping is not implemented since the file name cannot be "
1178                "inferred. Please implement a (potentially MLIR) module -> "
1179                "filename heuristic.";
1180       }
1181     }
1182 
1183     return result;
1184   };
1185 
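  // Select the thread pool used for sharded compilation. The value of
  // --xla_gpu_force_compilation_parallelism appears to mean:
  //   0 - use the thread pool passed in via CompileOptions (if any),
  //   1 - no parallelism; compile the whole module on this thread,
  //   N - create a dedicated pool with N threads.
  // For example, parallel compilation can be disabled with
  //   XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1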
1186   tensorflow::thread::ThreadPool* thread_pool;
1187   std::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
1188   switch (
1189       module_config.debug_options().xla_gpu_force_compilation_parallelism()) {
1190     case 0:
1191       thread_pool = options.thread_pool;
1192       break;
1193     case 1:
1194       thread_pool = nullptr;
1195       break;
1196     default:
1197       overriding_thread_pool.emplace(
1198           tensorflow::Env::Default(), "",
1199           module_config.debug_options()
1200               .xla_gpu_force_compilation_parallelism());
1201       thread_pool = &*overriding_thread_pool;
1202       break;
1203   }
1204 
1205   if (!thread_pool) {
1206     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
1207                                  /*shard_number=*/std::nullopt);
1208   }
1209 
1210   // Test whether LinkModules is supported.
1211   if (this->LinkModules(stream_exec, {}).status().code() ==
1212       tensorflow::error::Code::UNIMPLEMENTED) {
1213     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
1214                                  /*shard_number=*/std::nullopt);
1215   }
1216 
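  // Count externally visible function definitions; this caps the number of
  // shards requested from llvm::SplitModule below, so we never create more
  // shards than there are functions to compile.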
1217   std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
1218   int num_functions = 0;
1219   for (llvm::Function& func : llvm_module->functions()) {
1220     if (!func.isDeclaration() &&
1221         func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
1222       num_functions++;
1223     }
1224   }
1225 
1226   // Record the names of some constant global variables and their initializers.
1227   // We'll change the linkage type of these variables from external to internal
1228   // to ensure constant-folding works properly after calling llvm::SplitModule.
1229   llvm::DenseMap<llvm::StringRef, llvm::Constant*> const_initializer_map;
1230   for (llvm::GlobalVariable& gv : llvm_module->globals()) {
1231     if (gv.hasName() && gv.isConstant() && gv.hasInitializer() &&
1232         gv.hasExternalLinkage()) {
1233       llvm::Constant* initializer = gv.getInitializer();
1234       unsigned int num_elements = 0;
1235       if (auto* caz =
1236               llvm::dyn_cast<llvm::ConstantAggregateZero>(initializer)) {
1237         num_elements = caz->getElementCount().getFixedValue();
1238       } else if (auto* cds = llvm::dyn_cast<llvm::ConstantDataSequential>(
1239                      initializer)) {
1240         num_elements = cds->getNumElements();
1241       }
1242       if (num_elements > 0) {
1243         const_initializer_map[gv.getName()] = initializer;
1244       }
1245     }
1246   }
1247 
1248   llvm::SplitModule(
1249       *llvm_module,
1250       std::max<unsigned>(
1251           1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
1252       [&](std::unique_ptr<llvm::Module> module) {
1253         // Change the linkage type of these global constant variables to internal.
1254         for (llvm::GlobalVariable& gv : module->globals()) {
1255           if (gv.hasName() && gv.isConstant() && !gv.hasInitializer() &&
1256               const_initializer_map.count(gv.getName()) != 0) {
1257             gv.setInitializer(const_initializer_map[gv.getName()]);
1258             gv.setLinkage(llvm::GlobalValue::InternalLinkage);
1259           }
1260         }
1261         llvm_modules.push_back(std::move(module));
1262       },
1263       /*PreserveLocals=*/true);
1264 
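  // Compile every shard on the thread pool as a relocatable object and wait
  // for all of them with a BlockingCounter.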
1265   std::vector<StatusOr<BackendCompileResult>> compile_results(
1266       llvm_modules.size());
1267   tensorflow::BlockingCounter counter(llvm_modules.size());
1268   for (int i = 0; i < llvm_modules.size(); i++) {
1269     thread_pool->Schedule(
1270         [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
1271           llvm::Module* original_module = llvm_modules[i].get();
1272           llvm::LLVMContext context;
1273           std::string buffer;
1274           llvm::raw_string_ostream error(buffer);
1275 
1276           std::unique_ptr<llvm::Module> new_llvm_module;
1277           // Switch to a new context by dumping and re-parsing LLVM IR. Each
1278           // thread has its own context to avoid race conditions.
1279           {
1280             std::string ir;
1281             {
1282               llvm::raw_string_ostream os(ir);
1283               original_module->print(os, nullptr);
1284             }
1285             llvm::SMDiagnostic err;
1286             new_llvm_module = llvm::parseAssemblyString(ir, err, context);
1287             if (!new_llvm_module) {
1288               std::string err_string;
1289               llvm::raw_string_ostream os(err_string);
1290               err.print(/*ProgName=*/nullptr, os, /*ShowColors=*/false);
1291               LOG(FATAL) << "Failed to parse IR: " << err_string;
1292             }
1293           }
1294 
1295           compile_results[i] = compile_single_module(
1296               new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
1297           counter.DecrementCount();
1298         });
1299   }
1300   counter.Wait();
1301 
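  // Concatenate the PTX produced for each shard (returned as the text part of
  // the result) and link the per-shard binaries into a single GPU image.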
1302   std::string ptx_snippets;
1303   std::vector<std::vector<uint8_t>> submodule_compile_results;
1304   for (auto& maybe_result : compile_results) {
1305     TF_ASSIGN_OR_RETURN(auto result, maybe_result);
1306     if (result.second.empty()) {
1307       continue;
1308     }
1309     ptx_snippets += result.first;
1310     ptx_snippets += "\n";
1311     submodule_compile_results.push_back(result.second);
1312   }
1313 
1314   auto maybe_backend_result =
1315       this->LinkModules(stream_exec, std::move(submodule_compile_results));
1316   if (!maybe_backend_result.ok()) {
1317     LOG(ERROR) << "The CUDA linking API did not work. Please use "
1318                   "XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to "
1319                   "bypass it, but expect longer compilation times due to "
1320                   "the lack of multi-threading.";
1321     return maybe_backend_result.status();
1322   }
1323 
1324   return std::make_pair(ptx_snippets, std::move(*maybe_backend_result));
1325 }
1326 
1327 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
1328     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
1329     const CompileOptions& options) {
1330   VLOG(1) << "Starting to compile HLO module " << module->name();
1331   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
1332   std::string slow_compilation_msg =
1333       absl::StrCat("Compiling module ", module->name());
1334   auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
1335 
1336   TF_RET_CHECK(stream_exec != nullptr);
1337 
1338   llvm::LLVMContext llvm_context;
1339 
1340   GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
1341 
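  // Run a cost analysis purely for logging: report the total bytes read and
  // written by the entry computation, and warn that --xla_hlo_profile is not
  // supported on GPU.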
1342   if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
1343     HloCostAnalysis::Options options{ShapeSizeBytesFunction()};
1344     options.set_bytes_per_second(
1345         stream_exec->GetDeviceDescription().memory_bandwidth());
1346     GpuHloCostAnalysis cost_analysis(options);
1347     TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
1348     VLOG(1) << "HLO memory read+written: "
1349             << tensorflow::strings::HumanReadableNumBytes(
1350                    cost_analysis.bytes_accessed());
1351     if (module->config().hlo_profiling_enabled()) {
1352       LOG(ERROR) << "--xla_hlo_profile for GPU is unsupported.";
1353     }
1354   }
1355 
1356   CompileModuleResults compile_module_results;
1357   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1358       module.get(), &llvm_context, target_triple_, data_layout_,
1359       stream_exec->platform()->Name(), stream_exec->platform()->id(),
1360       gpu_device_info,
1361       stream_exec->GetDeviceDescription().cuda_compute_capability(),
1362       stream_exec->GetDeviceDescription().rocm_compute_capability(),
1363       GetCanShareBuffer(), pointer_size_, &compile_module_results,
1364       stream_exec));
1365 
1366   if (user_pre_optimization_hook_) {
1367     user_pre_optimization_hook_(*compile_module_results.llvm_module);
1368   }
1369   std::string ir_module_string_before_opt;
1370   const bool embed_ir_in_executable =
1371       module->config().debug_options().xla_embed_ir_in_executable();
1372   if (embed_ir_in_executable) {
1373     ir_module_string_before_opt =
1374         llvm_ir::DumpModuleToString(*compile_module_results.llvm_module);
1375   }
1376 
1377   llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
1378                            /*optimized=*/false);
1379 
1380   using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1381   TF_ASSIGN_OR_RETURN(
1382       BackendCompileResult backend_result,
1383       CompileToTargetBinary(module->config(),
1384                             std::move(compile_module_results.llvm_module),
1385                             stream_exec, options, module.get()));
1386   if (DumpingEnabledForHloModule(*module) &&
1387       std::holds_alternative<OwnedThunkSequence>(
1388           compile_module_results.executable)) {
1389     const ThunkSequence& thunk_sequence =
1390         *std::get<OwnedThunkSequence>(compile_module_results.executable);
1391     DumpToFileInDirOrStdout(*module, "", "thunk_sequence",
1392                             thunk_sequence.ToString());
1393   }
1394 
1395   auto buffer_assignment_proto = std::make_unique<BufferAssignmentProto>(
1396       compile_module_results.buffer_assignment->ToProto());
1397 
1398   // Make it shared to be captured in the following lambda.
1399   std::shared_ptr<const BufferAssignment> buffer_assignment(
1400       std::move(compile_module_results.buffer_assignment));
1401 
1402   GpuVersion gpu_version = GetGpuVersion(stream_exec);
1403   TF_ASSIGN_OR_RETURN(
1404       auto gpu_executable,
1405       GpuExecutable::Create(
1406           {std::move(backend_result.first), std::move(backend_result.second),
1407            gpu_version, std::move(compile_module_results.executable),
1408            compile_module_results.entry_func_attrs,
1409            std::move(compile_module_results.constants),
1410            std::move(compile_module_results.output_info),
1411            compile_module_results.module_name,
1412            compile_module_results.output_shape,
1413            std::move(compile_module_results.allocations),
1414            std::move(buffer_assignment_proto),
1415            [buffer_assignment] { return buffer_assignment->ToVerboseString(); },
1416            std::move(module)}));
1417   if (embed_ir_in_executable) {
1418     DCHECK_NE("", ir_module_string_before_opt);
1419     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
1420   }
1421 
1422   // Dump the computation proto and buffer assignment for debugging and
1423   // testing, if dumping or embed_ir_in_executable is enabled.
1424   if (embed_ir_in_executable ||
1425       DumpingEnabledForHloModule(gpu_executable->module())) {
1426     auto hlo_proto = std::make_unique<HloProto>();
1427     *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto();
1428     *hlo_proto->mutable_buffer_assignment() = buffer_assignment->ToProto();
1429     gpu_executable->set_hlo_proto(std::move(hlo_proto));
1430   }
1431   gpu_executable->set_debug_info(buffer_assignment->GetStats().ToString());
1432   return static_cast<std::unique_ptr<Executable>>(std::move(gpu_executable));
1433 }
1434 
1435 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
1436 GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
1437                                 const AotCompilationOptions& options) {
1438 #if XLA_ENABLE_XLIR
1439   CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
1440   CHECK(options.executor() != nullptr);
1441   auto stream_exec = options.executor();
1442 
1443   std::vector<std::unique_ptr<HloModule>> modules =
1444       module_group->ConsumeModules();
1445   std::vector<std::unique_ptr<AotCompilationResult>> results;
1446 
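  // For each module: lower it to a JitRt program, JIT-compile that program
  // with specialization disabled, and serialize the resulting object file
  // into a JitRtAotCompilationResult.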
1447   for (const auto& module : modules) {
1448     llvm::LLVMContext llvm_context;
1449     GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
1450 
1451     // Compile the module
1452     CompileModuleResults compile_module_results;
1453     TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1454         module.get(), &llvm_context, target_triple_, data_layout_,
1455         stream_exec->platform()->Name(), stream_exec->platform()->id(),
1456         gpu_device_info,
1457         stream_exec->GetDeviceDescription().cuda_compute_capability(),
1458         stream_exec->GetDeviceDescription().rocm_compute_capability(),
1459         GetCanShareBuffer(), pointer_size_, &compile_module_results));
1460     auto& compiled_executable = compile_module_results.executable;
1461 
1462     if (!std::holds_alternative<OwnedJitRtProgram>(compiled_executable)) {
1463       return InternalError("JitRtProgram not provided");
1464     }
1465 
1466     const auto& program = std::get<OwnedJitRtProgram>(compiled_executable);
1467 
1468     // Options for the default JitRt compilation pipeline.
1469     runtime::CompilationPipelineOptions copts;
1470 
1471     // Options for constructing JitRt JitExecutable.
1472     runtime::JitExecutable::Options opts;
1473     opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
1474     opts.compiler.register_dialects =
1475         runtime::RegisterDefaultXlaRuntimeDialects;
1476 
1477     // Register JitRt Gpu runtime custom calls with the linker.
1478     opts.compiler.symbols_binding = runtime::ToSymbolsBinding(
1479         JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);
1480 
1481     opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
1482       runtime::CreateDefaultXlaRuntimeCompilationPipeline(pm, copts);
1483     };
1484 
1485     // Instantiate new JitExecutable from the MLIR source.
1486     auto jit_executable = runtime::JitExecutable::Instantiate(
1487         program->module, program->entry_point, opts);
1488     if (auto err = jit_executable.takeError())
1489       return InternalError("Failed to compile JitRt program: %s",
1490                            tfrt::StrCat(err));
1491 
1492     // For static shapes we can always serialize only the default executable.
1493     runtime::Executable& executable = jit_executable->DefaultExecutable().get();
1494 
1495     // Check if JitRt executable saved the compilation result.
1496     std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
1497     if (!obj_file)
1498       return InternalError("JitRt executable didn't save the obj file");
1499 
1500     std::string data(obj_file->getBuffer().data(),
1501                      obj_file->getBuffer().size());
1502     results.emplace_back(std::make_unique<xla::gpu::JitRtAotCompilationResult>(
1503         module->ToProto(), data, program->module,
1504         compile_module_results.entry_func_attrs));
1505   }
1506   return std::move(results);
1507 #else
1508   return Unimplemented("");
1509 #endif  // XLA_ENABLE_XLIR
1510 }
1511 
1512 HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
1513   // Capture just the pointer size, not the entire GpuCompiler object.
1514   return [pointer_size = pointer_size_](const Shape& shape) {
1515     return GetSizeOfShape(shape, pointer_size);
1516   };
1517 }
1518 
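// Thin wrapper around CompileModuleToLlvmIrImpl that keeps only the LLVM
// module and uses DummyCanShareBufferFunction for buffer-sharing decisions.
//
// A minimal, hypothetical usage sketch (the device-specific arguments are
// assumed to have been queried from the target device by the caller):
//
//   llvm::LLVMContext context;
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<llvm::Module> llvm_module,
//       CompileModuleToLlvmIr(hlo_module, &context, target_triple, data_layout,
//                             platform_name, platform_id, gpu_device_info,
//                             cuda_cc, rocm_cc, pointer_size));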
1519 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
1520     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
1521     const std::string& target_triple, const std::string& data_layout,
1522     const std::string& platform_name, const se::Platform::Id platform_id,
1523     GpuDeviceInfo gpu_device_info,
1524     se::CudaComputeCapability cuda_compute_capability,
1525     se::RocmComputeCapability rocm_compute_capability, int pointer_size) {
1526   CompileModuleResults results;
1527   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1528       hlo_module, llvm_context, target_triple, data_layout, platform_name,
1529       platform_id, gpu_device_info, cuda_compute_capability,
1530       rocm_compute_capability, DummyCanShareBufferFunction, pointer_size,
1531       &results));
1532   return std::move(results.llvm_module);
1533 }
1534 
1535 // Analyze the function signature to reconstruct a vector of BufferAllocation
1536 // objects, as well as other output information.
1537 //
1538 // This function also serves as a half-baked verifier for function arg
1539 // attributes, since a full verifier doesn't exist yet.
1540 static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
1541                                     std::vector<BufferAllocation>* allocations,
1542                                     OutputInfoMap* output_info,
1543                                     Shape* output_shape,
1544                                     EntryFunctionAttributes* entry_func_attrs) {
1545   CHECK(allocations->empty());
1546   allocations->reserve(func.getNumArguments());
1547 
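  // Derive a byte size for each entry argument as element count times element
  // size; this presumably assumes statically shaped, densely packed buffers.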
1548   std::vector<int64_t> buffer_sizes;
1549   for (int i = 0; i < func.getNumArguments(); i++) {
1550     mlir::BlockArgument arg = func.getArgument(i);
1551 
1552     TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
1553     mlir::ShapedType type = arg.getType().cast<mlir::ShapedType>();
1554     TF_ASSIGN_OR_RETURN(auto element_type_bytes,
1555                         GetElementTypeBytes(type.getElementType()));
1556     size_t size = type.getNumElements() * element_type_bytes;
1557     buffer_sizes.push_back(size);
1558   }
1559 
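  // Sanity-check the argument attributes: only the lmhlo.* attributes listed
  // below are expected on entry-function arguments.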
1560   for (int i = 0; i < func.getNumArguments(); i++) {
1561     for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
1562       TF_RET_CHECK(attr.getName() == "lmhlo.params" ||
1563                    attr.getName() == "lmhlo.param_shape_index" ||
1564                    attr.getName() == "lmhlo.constant_name" ||
1565                    attr.getName() == "lmhlo.must_alias" ||
1566                    attr.getName() == "lmhlo.output_index");
1567     }
1568   }
1569 
1570   // Encode buffer parameter metadata in a proto for persisting, because BEF
1571   // doesn't persist function attributes.
1572   for (int i = 0; i < func.getNumArguments(); i++) {
1573     auto buffer = entry_func_attrs->add_buffers();
1574     if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1575       buffer->set_lmhlo_params_present(true);
1576       buffer->set_lmhlo_params(param_attr.cast<mlir::IntegerAttr>().getInt());
1577     }
1578     if (auto shape_index_attr = func.getArgAttr(i, "lmhlo.param_shape_index")) {
1579       auto param_shape_index = buffer->mutable_lmhlo_param_shape_index();
1580       for (const llvm::APInt& element :
1581            shape_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1582         param_shape_index->add_indices(element.getSExtValue());
1583       }
1584     }
1585     if (auto constant_name_attr = func.getArgAttr(i, "lmhlo.constant_name")) {
1586       buffer->set_lmhlo_constant_name(
1587           constant_name_attr.cast<mlir::StringAttr>().str());
1588     }
1589     if (func.getArgAttr(i, "lmhlo.must_alias")) {
1590       buffer->set_lmhlo_must_alias(true);
1591     }
1592     if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
1593       auto output_index = buffer->mutable_lmhlo_output_index();
1594       for (const llvm::APInt& element :
1595            output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1596         output_index->add_indices(element.getSExtValue());
1597       }
1598     }
1599   }
1600   entry_func_attrs->set_result_xla_shape(
1601       func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
1602           .getValue()
1603           .str());
1604 
1605   return GpuExecutable::SetUpMlirAllocation(func, buffer_sizes, allocations,
1606                                             output_info, output_shape);
1607 }
1608 
1609 StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
1610     GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
1611     const HloModuleConfig& module_config,
1612     const Compiler::CompileOptions& options,
1613     absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
1614     std::unique_ptr<llvm::Module> llvm_module,
1615     IrEmitterContext* ir_emitter_context) {
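  // Standalone LMHLO entry point: reconstruct buffer allocations from the
  // entry function's signature, emit thunks with IrEmitterUnnested, compile
  // the LLVM module to a target binary, and wrap everything in a
  // GpuExecutable.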
1616   mlir::func::FuncOp entry_function =
1617       mlir::cast<mlir::func::FuncOp>(module.lookupSymbol(llvm::StringRef(
1618           entry_function_name.data(), entry_function_name.size())));
1619 
1620   std::vector<BufferAllocation> allocations;
1621   OutputInfoMap output_info;
1622   Shape output_shape;
1623   EntryFunctionAttributes entry_func_attrs;
1624   TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
1625                                            &output_info, &output_shape,
1626                                            &entry_func_attrs));
1627 
1628   TF_RET_CHECK(!allocations.empty());
1629 
1630   ir_emitter_context->set_allocations(allocations);
1631 
1632   TF_ASSIGN_OR_RETURN(auto ir_emitter, IrEmitterUnnested::Create(
1633                                            module_config, ir_emitter_context));
1634   TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.getBody()));
1635 
1636   bool supports_runtime_managed_constants =
1637       // TODO(b/218907125): Implement this feature for ROCm as well.
1638       compiler->PlatformId() != se::rocm::kROCmPlatformId;
1639   if (supports_runtime_managed_constants) {
1640     // Remove these globals from the generated code to indicate that XLA is
1641     // responsible for allocating and initializing them.
1642     RemoveUnusedAndUninitializedGlobals(ir_emitter_context->llvm_module(),
1643                                         ir_emitter_context->constants());
1644   }
1645 
1646   auto thunk_sequence = ir_emitter->ConsumeThunkSequence();
1647 
1648   using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1649   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
1650                       compiler->CompileToTargetBinary(
1651                           module_config, std::move(llvm_module), stream_exec,
1652                           options, /*debug_module=*/nullptr));
1653 
1654   GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
1655   return GpuExecutable::Create(
1656       {std::move(backend_result.first), std::move(backend_result.second),
1657        gpu_version, std::move(thunk_sequence), entry_func_attrs,
1658        std::move(ir_emitter_context->constants()), std::move(output_info),
1659        module_name, output_shape, std::move(allocations)});
1660 }
1661 
1662 }  // namespace gpu
1663 }  // namespace xla
1664