1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
17
18 #include <stdlib.h>
19
20 #include <atomic>
21 #include <functional>
22 #include <iterator>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <variant>
27 #include <vector>
28
29 #include "absl/strings/numbers.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/types/variant.h"
32 #include "llvm/AsmParser/Parser.h"
33 #include "llvm/Bitcode/BitcodeReader.h"
34 #include "llvm/Bitcode/BitcodeWriter.h"
35 #include "llvm/IR/DiagnosticInfo.h"
36 #include "llvm/IR/DiagnosticPrinter.h"
37 #include "llvm/IR/LLVMContext.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Transforms/Utils/SplitModule.h"
41 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
42 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
43 #include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
45 #include "mlir/InitAllDialects.h" // from @llvm-project
46 #include "mlir/Pass/PassManager.h" // from @llvm-project
47 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
48 #include "mlir/Transforms/LocationSnapshot.h" // from @llvm-project
49 #include "mlir/Transforms/Passes.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/utils/name_utils.h"
51 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
52 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
53 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
54 #include "tensorflow/compiler/xla/protobuf_util.h"
55 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
56 #include "tensorflow/compiler/xla/service/all_gather_broadcast_reorder.h"
57 #include "tensorflow/compiler/xla/service/all_gather_combiner.h"
58 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
59 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
60 #include "tensorflow/compiler/xla/service/all_reduce_contiguous.h"
61 #include "tensorflow/compiler/xla/service/all_reduce_folder.h"
62 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
63 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
64 #include "tensorflow/compiler/xla/service/async_collective_creator.h"
65 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
66 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
67 #include "tensorflow/compiler/xla/service/bitcast_decomposer.h"
68 #include "tensorflow/compiler/xla/service/bitcast_dtypes_expander.h"
69 #include "tensorflow/compiler/xla/service/broadcast_canonicalizer.h"
70 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
71 #include "tensorflow/compiler/xla/service/call_inliner.h"
72 #include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
73 #include "tensorflow/compiler/xla/service/comparison_expander.h"
74 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
75 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
76 #include "tensorflow/compiler/xla/service/convert_mover.h"
77 #include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
78 #include "tensorflow/compiler/xla/service/convolution_pred_expander.h"
79 #include "tensorflow/compiler/xla/service/copy_insertion.h"
80 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
81 #include "tensorflow/compiler/xla/service/dot_merger.h"
82 #include "tensorflow/compiler/xla/service/dump.h"
83 #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
84 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
85 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
86 #include "tensorflow/compiler/xla/service/eigh_expander.h"
87 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
88 #include "tensorflow/compiler/xla/service/gather_expander.h"
89 #include "tensorflow/compiler/xla/service/gather_simplifier.h"
90 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
91 #include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"
92 #include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"
93 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
94 #include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
95 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
96 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
97 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
98 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
99 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
100 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h"
101 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
102 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
103 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
104 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
105 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
106 #include "tensorflow/compiler/xla/service/gpu/gpu_shape_verifier.h"
107 #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.h"
108 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
109 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
110 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
111 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
112 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
113 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
114 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
115 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
116 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
117 #include "tensorflow/compiler/xla/service/gpu/metrics.h"
118 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
119 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
120 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
121 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
122 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
123 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
124 #include "tensorflow/compiler/xla/service/gpu/runtime_intrinsics.h"
125 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
126 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
127 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
128 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
129 #include "tensorflow/compiler/xla/service/hlo_computation.h"
130 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
131 #include "tensorflow/compiler/xla/service/hlo_cse.h"
132 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
133 #include "tensorflow/compiler/xla/service/hlo_dce.h"
134 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
135 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
136 #include "tensorflow/compiler/xla/service/hlo_parser.h"
137 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
138 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
139 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
140 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
141 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
142 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
143 #include "tensorflow/compiler/xla/service/layout_normalization.h"
144 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
145 #include "tensorflow/compiler/xla/service/logistic_expander.h"
146 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
147 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
148 #include "tensorflow/compiler/xla/service/optimization_barrier_expander.h"
149 #include "tensorflow/compiler/xla/service/qr_expander.h"
150 #include "tensorflow/compiler/xla/service/real_imag_expander.h"
151 #include "tensorflow/compiler/xla/service/reduce_decomposer.h"
152 #include "tensorflow/compiler/xla/service/reduce_scatter_combiner.h"
153 #include "tensorflow/compiler/xla/service/reshape_decomposer.h"
154 #include "tensorflow/compiler/xla/service/reshape_mover.h"
155 #include "tensorflow/compiler/xla/service/result_caster.h"
156 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
157 #include "tensorflow/compiler/xla/service/rng_expander.h"
158 #include "tensorflow/compiler/xla/service/scatter_simplifier.h"
159 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
160 #include "tensorflow/compiler/xla/service/sharding_remover.h"
161 #include "tensorflow/compiler/xla/service/simplify_fp_conversions.h"
162 #include "tensorflow/compiler/xla/service/slice_sinker.h"
163 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
164 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
165 #include "tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.h"
166 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
167 #include "tensorflow/compiler/xla/service/transpose_folding.h"
168 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
169 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
170 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
171 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
172 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
173 #include "tensorflow/compiler/xla/status_macros.h"
174 #include "tensorflow/compiler/xla/types.h"
175 #include "tensorflow/compiler/xla/util.h"
176 #include "tensorflow/core/lib/core/status.h"
177 #include "tensorflow/core/platform/blocking_counter.h"
178 #include "tensorflow/core/platform/casts.h"
179 #include "tensorflow/core/platform/env.h"
180 #include "tensorflow/core/platform/logging.h"
181 #include "tensorflow/core/platform/regexp.h"
182 #include "tensorflow/core/platform/statusor.h"
183 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
184 #include "tensorflow/core/platform/threadpool.h"
185 #include "tensorflow/core/profiler/lib/traceme.h"
186 #include "tensorflow/core/util/env_var.h"
187
188 #if XLA_ENABLE_XLIR
189 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pass_utils.h"
190 #include "tensorflow/compiler/xla/mlir/transforms/runtime/compilation_pipeline.h"
191 #include "tensorflow/compiler/xla/runtime/jit_executable.h"
192 #include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
193 #endif // XLA_ENABLE_XLIR
194
195 namespace xla {
196 namespace gpu {
197 namespace {
198
199 class GpuBfloat16Support : public BFloat16Support {
200 public:
201 explicit GpuBfloat16Support(bool supports_matrix_multiplication,
202 se::StreamExecutor* stream_exec)
203 : supports_matrix_multiplication_(supports_matrix_multiplication),
204 stream_exec_(stream_exec) {}
205
206 bool SupportsBF16Operand(const HloInstruction& hlo,
207 int64_t operand_index) const override {
208 return BFloat16Support::SupportsBF16Operand(hlo, operand_index) ||
209 IsSupported(hlo);
210 }
211
212 // Returns whether the backend supports BF16 output for the HLO instruction.
213 bool SupportsBF16Output(const HloInstruction& hlo) const override {
214 return BFloat16Support::SupportsBF16Output(hlo) || IsSupported(hlo);
215 }
216
217 private:
218 bool IsSupported(const HloInstruction& hlo) const {
219 switch (hlo.opcode()) {
220 // Collective ops.
221 case HloOpcode::kAllGather:
222 case HloOpcode::kAllReduce:
223 case HloOpcode::kAllReduceStart:
224 case HloOpcode::kAllReduceDone:
225 case HloOpcode::kAllToAll:
226 case HloOpcode::kCollectivePermute:
227 case HloOpcode::kReduceScatter:
228 // Data movement only ops.
229 case HloOpcode::kBroadcast:
230 case HloOpcode::kConcatenate:
231 case HloOpcode::kCopy:
232 case HloOpcode::kDynamicSlice:
233 case HloOpcode::kDynamicUpdateSlice:
234 case HloOpcode::kGather:
235 case HloOpcode::kPad:
236 case HloOpcode::kReshape:
237 case HloOpcode::kReverse:
238 case HloOpcode::kScatter:
239 case HloOpcode::kSelect:
240 case HloOpcode::kSelectAndScatter:
241 case HloOpcode::kSlice:
242 case HloOpcode::kTranspose:
243 // Other special ops.
244 case HloOpcode::kBitcast:
245 return true;
246 case HloOpcode::kConvolution:
247 return IsConvBF16Supported();
248 default:
249 return supports_matrix_multiplication_ &&
250 gpu::IsMatrixMultiplication(hlo);
251 }
252 }
253
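// Note (summary of the check below): BF16 convolutions require cuDNN 8.2 or
// newer running on an Ampere-or-newer GPU.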
254 bool IsConvBF16Supported() const {
255 if (se::dnn::DnnSupport* dnn = stream_exec_->AsDnn()) {
256 se::port::StatusOr<se::dnn::VersionInfo> cudnn_version =
257 dnn->GetVersion();
258 return cudnn_version.ok() &&
259 (cudnn_version->major_version() > 8 ||
260 (cudnn_version->major_version() == 8 &&
261 cudnn_version->minor_version() >= 2)) &&
262 stream_exec_->GetDeviceDescription()
263 .cuda_compute_capability()
264 .IsAtLeast(se::CudaComputeCapability::AMPERE);
265 }
266 return false;
267 }
268
269 bool supports_matrix_multiplication_;
270 se::StreamExecutor* stream_exec_;
271 };
272
273 int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
274 if (shape.is_static() || shape.IsTuple()) {
275 return ShapeUtil::ByteSizeOf(shape, pointer_size);
276 }
277 // Each dynamic dimension size is represented as an S32.
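// For illustration (hypothetical shape): a dynamic f32[<=16,<=8] array takes
// 16 * 8 * 4 = 512 bytes of data plus 2 * sizeof(int32_t) = 8 bytes of
// metadata, 520 bytes in total.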
278 int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size();
279 return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
280 }
281
282 bool ConvIsLowerable(HloInstruction* conv) {
283 return conv_matchers::CanImplementAsGpuForwardConv(conv) ||
284 std::get<0>(conv_matchers::MatchBackwardFilter(conv)) ||
285 std::get<0>(conv_matchers::MatchBackwardInput(conv));
286 }
287
288 } // end anonymous namespace
289
290 using OwnedThunkSequence = GpuExecutable::OwnedThunkSequence;
291 using OwnedJitRtProgram = GpuExecutable::OwnedJitRtProgram;
292
293 StatusOr<std::unique_ptr<Executable>> JitRtAotCompilationResult::LoadExecutable(
294 Compiler* compiler, se::StreamExecutor* executor) const {
295 TF_ASSIGN_OR_RETURN(
296 HloModuleConfig hlo_module_config,
297 HloModule::CreateModuleConfigFromProto(
298 jitrt_executable_.hlo_module_proto(), GetDebugOptionsFromFlags()));
299 TF_ASSIGN_OR_RETURN(
300 std::unique_ptr<HloModule> hlo_module,
301 HloModule::CreateFromProto(jitrt_executable_.hlo_module_proto(),
302 hlo_module_config));
303 auto gpu_compiler = tensorflow::down_cast<GpuCompiler*>(compiler);
304 return GpuExecutable::LoadFromObjFile(
305 std::move(hlo_module), jitrt_executable_.obj_file(),
306 jitrt_executable_.mlir_module(), jitrt_executable_.entry_func_attrs(),
307 GetDebugOptionsFromFlags(), gpu_compiler->GetGpuVersion(executor),
308 executor);
309 }
310
311 GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
312 const char* target_triple, const char* data_layout)
313 : platform_id_(platform_id),
314 target_triple_(target_triple),
315 data_layout_(data_layout),
316 pointer_size_(llvm::DataLayout(data_layout)
317 .getPointerSize(0 /* default address space */)) {}
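// Note: for the 64-bit NVPTX and AMDGPU data layouts used by the CUDA and ROCm
// backends, pointer_size_ evaluates to 8 bytes.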
318
319 namespace {
320 // Adds the HloVerifier for GPU to the given pipeline.
321 void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
322 bool debug_only = false) {
323 std::unique_ptr<TargetVerifierMetadata> verifier_metadata =
324 std::make_unique<GpuVerifierMetadata>(std::move(opts));
325 if (debug_only) {
326 pipeline->AddInvariantCheckerDebug<HloVerifier>(
327 std::move(verifier_metadata), "hlo verifier (debug)");
328 } else {
329 pipeline->AddInvariantChecker<HloVerifier>(std::move(verifier_metadata),
330 "hlo verifier");
331 }
332 }
333 } // namespace
334
335 // Runs optimization passes on the given HLO module.
336 Status GpuCompiler::OptimizeHloModule(
337 HloModule* hlo_module, se::StreamExecutor* stream_exec,
338 se::DeviceMemoryAllocator* device_allocator) {
339 const DebugOptions& debug_options = hlo_module->config().debug_options();
340
341 AlgebraicSimplifierOptions layout_insensitive_algsimp_opts({},
342 ConvIsLowerable);
343 // "slow" minmax means we propagate nan.
344 layout_insensitive_algsimp_opts.set_minmax_propagate_nan(
345 !debug_options.xla_gpu_enable_fast_min_max());
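// For example (illustrative): with NaN propagation, max(x, NaN) evaluates to
// NaN, whereas the "fast" lowering may return x instead.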
346
347 const se::Platform* platform = stream_exec->platform();
348 if (platform->Name() == "ROCM") {
349 // SwapConvOperands does not yet work on ROCM
350 layout_insensitive_algsimp_opts.set_enable_conv_operand_swap(false);
351 }
352
353 if (hlo_module->config().use_spmd_partitioning()) {
354 HloPassPipeline spmd_pipeline("spmd-partitioner");
355 AddHloVerifier(&spmd_pipeline);
356 const int64_t num_partitions = hlo_module->config().num_partitions();
357 if (num_partitions > 1) {
358 // Run some IR cleanup passes before running the SPMD partitioning
359 // passes.
360 spmd_pipeline.AddPass<CallInliner>();
361 spmd_pipeline.AddPass<ZeroSizedHloElimination>();
362 spmd_pipeline.AddPass<ConditionalCanonicalizer>();
363
364 HloPassPipeline& spmd_simplify =
365 spmd_pipeline.AddPass<HloPassFix<HloPassPipeline>>("spmd-simplify");
366
367 spmd_simplify.AddPass<AlgebraicSimplifier>(
368 layout_insensitive_algsimp_opts);
369
370 spmd_simplify.AddPass<SortSimplifier>();
371 spmd_simplify.AddPass<TupleSimplifier>();
372 spmd_simplify.AddPass<ScatterSimplifier>();
373 spmd_simplify.AddPass<ScatterExpander>(
374 ScatterExpander::kEliminateSimpleScatters);
375 spmd_simplify.AddPass<GatherSimplifier>();
376 spmd_simplify.AddPass<GatherExpander>(
377 GatherExpander::kEliminateSimpleGathers);
378 spmd_simplify.AddPass<WhileLoopConstantSinking>();
379 spmd_simplify.AddPass<WhileLoopSimplifier>();
380
381 spmd_simplify.AddPass<ReshapeMover>();
382 spmd_simplify.AddPass<HloConstantFolding>();
383 spmd_simplify.AddPass<ConditionalSimplifier>();
384 spmd_simplify.AddPass<HloDCE>();
385
386 spmd_pipeline.AddPass<ShardingPropagation>(
387 /*is_spmd=*/true, /*propagate_metadata=*/false,
388 hlo_module->config().allow_spmd_sharding_propagation_to_output());
389 spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(
390 num_partitions, hlo_module->config().replica_count());
391 } else {
392 // Remove redundant sharding ops when partition_count == 1.
393 spmd_pipeline.AddPass<ShardingRemover>();
394 spmd_pipeline.AddPass<HloDCE>();
395 }
396 TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
397 }
398
399 {
400 HloPassPipeline pipeline("optimization");
401 AddHloVerifier(&pipeline);
402 pipeline.AddPass<AllToAllDecomposer>();
403
404 HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
405 return !stream_exec->GetDeviceDescription()
406 .cuda_compute_capability()
407 .IsAtLeast(se::CudaComputeCapability::VOLTA) ||
408 !gpu::IsMatrixMultiplication(*instr);
409 };
410
411 pipeline.AddPass<OperandUpcaster>(upcaster_filter);
412 pipeline.AddPass<ResultCaster>(upcaster_filter);
413
414 // Expand random number generation.
415 pipeline.AddPass<RngExpander>();
416 pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
417
418 // Comparison total order expander
419 pipeline.AddPass<ComparisonExpander>();
420
421 // Remove zero-sized HLO from the input so that other passes don't have to
422 // handle it.
423 pipeline.AddPass<ZeroSizedHloElimination>();
424
425 if (debug_options.xla_gpu_deterministic_ops()) {
426 // Scatter is nondeterministic, so eliminate all Scatters.
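// (Scatter updates to the same index are emitted with atomics, so the order
// in which floating-point updates are applied, and hence the result, can
// differ from run to run.)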
427 pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
428 } else {
429 // Only Scatters unsupported on XLA:GPU are eliminated.
430 pipeline.AddPass<GpuScatterExpander>();
431 }
432 // TODO(phawkins): replace QR and Eigh decompositions with calls to
433 // cuSOLVER.
434 pipeline.AddPass<QrExpander>();
435 pipeline.AddPass<EighExpander>();
436
437 pipeline.AddPass<DynamicIndexSplitter>();
438
439 // TODO(b/64094172): make Call work on GPU instead of inlining.
440 pipeline.AddPass<CallInliner>();
441
442 pipeline.AddPass<DotDecomposer>();
443
444 pipeline.AddPass<Convolution4DExpander>();
445
446 // Replace PRED convolutions with F16.
447 pipeline.AddPass<ConvolutionPredExpander>();
448
449 // Expand the sort op to support stable sorting if required.
450 pipeline.AddPass<StableSortExpander>();
451
452 GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/true,
453 stream_exec);
454 pipeline.AddPass<BFloat16Normalization>(&bf16);
455
456 // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
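// e.g. (sketch) convert(convert(%x: f32, bf16), f32) collapses back to %x,
// accepting the small numeric difference from skipping the bf16 rounding.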
457 if (debug_options.xla_gpu_simplify_all_fp_conversions())
458 pipeline.AddPass<SimplifyFPConversions>();
459
460 pipeline.AddPass<BatchNormExpander>(
461 /*rewrite_training_op=*/true,
462 /*rewrite_inference_op=*/true,
463 /*rewrite_grad_op=*/true);
464
465 pipeline.AddPass<LogisticExpander>(
466 /*expansion_type=*/LogisticExpansionType::kExp);
467 pipeline.AddPass<ConditionalCanonicalizer>();
468 pipeline.AddPass<DynamicDimensionSimplifier>();
469
470 DynamicPadderOptions dynamic_padder_options;
471
472 switch (hlo_module->config().debug_options().xla_gpu_shape_checks()) {
473 case DebugOptions::IGNORE:
474 dynamic_padder_options.shape_check_mode =
475 DynamicDimensionInference::ShapeCheckMode::kIgnore;
476 break;
477 case DebugOptions::RUNTIME: {
478 dynamic_padder_options.shape_check_mode =
479 DynamicDimensionInference::ShapeCheckMode::kRuntime;
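// As a sketch, the generated check is a side-effecting custom call that
// produces a token, roughly:
//   %assert = token[] custom-call(%inst),
//       custom_call_target=<kXlaGpuAssertCustomCallTag>
// carrying the message "Buffers have different size at runtime".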
480 dynamic_padder_options.assertion_generator = [&](HloInstruction* inst) {
481 auto created = Cast<HloCustomCallInstruction>(
482 inst->parent()->AddInstruction(HloInstruction::CreateCustomCall(
483 ShapeUtil::MakeTokenShape(), {inst},
484 kXlaGpuAssertCustomCallTag,
485 "Buffers have different size at runtime",
486 API_VERSION_STATUS_RETURNING)));
487 created->set_custom_call_has_side_effect(true);
488 };
489 break;
490 }
491 case DebugOptions::COMPILE_TIME:
492 dynamic_padder_options.shape_check_mode =
493 DynamicDimensionInference::ShapeCheckMode::kCompileTime;
494 break;
495 default:
496 LOG(FATAL) << "Unreachable";
497 }
498
499 pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
500
501 // Build simplification pipeline. The passes in here are run to a fixed
502 // point.
503 [&, &pipeline =
504 pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
505 AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true);
506
507 // BatchNormExpander can create zero-sized ops, so zero-sized HLO
508 // elimination has to come after that pass.
509 pipeline.AddPass<ZeroSizedHloElimination>();
510
511 pipeline.AddPass<GatherSimplifier>();
512 pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
513 pipeline.AddPass<ScatterSimplifier>();
514 pipeline.AddPass<ScatterExpander>(
515 ScatterExpander::kEliminateSimpleScatters);
516 pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
517 pipeline.AddPass<BitcastDtypesExpander>();
518 // AlgebraicSimplifier may add contracting dimensions to a dot.
519 pipeline.AddPass<DotDecomposer>();
520 // Only merge "smallish" dots. This threshold was not set carefully, but
521 // so far we know that 1 MB is too small.
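// (int64_t{16} << 20 below is 16 MiB.)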
522 pipeline.AddPass<DotMerger>(/*max_size_to_merge=*/int64_t{16} << 20);
523 pipeline.AddPass<SortSimplifier>();
524 pipeline.AddPass<TupleSimplifier>();
525 pipeline.AddPass<WhileLoopConstantSinking>();
526 pipeline.AddPass<WhileLoopSimplifier>();
527
528 // TODO(b/134075051): Re-enable after b/134075051 is fixed.
529 // pipeline.AddPass<SliceSinker>();
530
531 pipeline.AddPass<ReshapeMover>();
532 pipeline.AddPass<HloConstantFolding>();
533 pipeline.AddPass<ConditionalSimplifier>();
534 pipeline.AddPass<RealImagExpander>();
535 pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot);
536 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
537 pipeline.AddPass<HloDCE>();
538 }();
539
540 // ConvertMover and ReshapeMover fight with each other: ConvertMover wants
541 // to move some converts down the graph, but ReshapeMover wants to move them
542 // up the graph. As a compromise, let ReshapeMover run to a fixed point,
543 // and then run ConvertMover + algsimp to a fixed point.
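// (Running both in one fixed-point pipeline could keep flipping between
// convert(reshape(x)) and reshape(convert(x)) and never converge.)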
544 [&, &pipeline =
545 pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification-2")] {
546 pipeline.AddPass<ConvertMover>();
547 pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
548 }();
549
550 // Run WhileLoopTripCountAnnotator at the end of the simplification
551 // pipeline, before layout assignment and fusion. This pass does some
552 // pattern-matching on while bodies/conditions, and this is where the HLO is
553 // "nicest".
554 //
555 // It's important that we don't make semantic changes (e.g. unrolling) to
556 // any `while` loops after this point, because otherwise the trip-count
557 // annotations added by this pass may not be correct after the
558 // modifications.
559 pipeline.AddPass<WhileLoopTripCountAnnotator>();
560 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
561 }
562
563 // Optimize collectives generated by SPMD partitioning. These passes also run
564 // without SPMD partitioning, so that all collectives get these optimizations.
565 {
566 HloPassPipeline collectives_pipeline("collective-optimizations");
567 collectives_pipeline.AddPass<AllReduceFolder>();
568 collectives_pipeline.AddPass<ReduceScatterCreator>();
569 collectives_pipeline.AddPass<AllReduceReassociate>();
570
571 // Run algebraic simplifier to reshape(broadcast) into a broadcast when
572 // the reshape is just adding a unit dimension. This will help with the
573 // AllGatherBroadcastReorder pass.
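// For illustration (sketch): %r = f32[4,1] reshape(f32[4] broadcast(%x))
// becomes %r = f32[4,1] broadcast(%x).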
574 collectives_pipeline.AddPass<AlgebraicSimplifier>(
575 layout_insensitive_algsimp_opts);
576
577 collectives_pipeline.AddPass<AllGatherBroadcastReorder>();
578 TF_RETURN_IF_ERROR(collectives_pipeline.Run(hlo_module).status());
579 }
580
581 // Run target-specific HLO optimization passes for convolution
582 // canonicalization.
583 TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
584 hlo_module, stream_exec, device_allocator));
585
586 {
587 // Run layout assignment in a separate pipeline from
588 // "post-layout-assignment" because we want everything after layout
589 // assignment to have a layout-sensitive invariant-checker, but
590 // HloPassPipeline also runs its invariant checker before any passes are
591 // run, which means the pipeline that contains layout assignment cannot contain
592 // a layout-sensitive verifier!
593 HloPassPipeline pipeline("layout assignment");
594 // Layout assignment uses alias analysis, which requires the call graph to
595 // be flattened.
596 pipeline.AddPass<FlattenCallGraph>();
597 ChannelLayoutConstraints layout_constraints;
598 pipeline.AddPass<GpuLayoutAssignment>(
599 hlo_module->mutable_entry_computation_layout(), stream_exec,
600 &layout_constraints);
601 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
602 }
603
604 // Run target-specific HLO optimization passes after layout assignment.
605 TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
606 device_allocator));
607
608 {
609 HloPassFix<HloPassPipeline> fusion("fusion");
610 // We try to split variadic ops with many parameters into several such ops
611 // to avoid exceeding the parameter space.
612 fusion.AddPass<VariadicOpSplitter>();
613 AddHloVerifier(
614 &fusion,
615 HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
616 LayoutAssignment::InstructionCanChangeLayout),
617 /*debug_only=*/true);
618 fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
619 fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
620 fusion.AddPass<FusionMerger>();
621 fusion.AddPass<GpuMultiOutputFusion>();
622 fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
623 /*only_fusion_computations=*/true);
624 fusion.AddPass<HloDCE>();
625 TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
626 }
627
628 {
629 HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
630 horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
631 horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
632 // FusionBitcastLift must be after InstructionFusion, as it undoes
633 // part of it.
634 // TODO(b/209005695) Re-enable once the bug is fixed.
635 // horizontal_fusion.AddPass<FusionBitcastLift>();
636 horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
637 /*only_fusion_computations=*/true);
638 horizontal_fusion.AddPass<HloDCE>();
639 TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
640 }
641
642 if (VLOG_IS_ON(2)) {
643 HloFusionStatsVisitor stats;
644 TF_RETURN_IF_ERROR(hlo_module->entry_computation()->Accept(&stats));
645 VLOG(2) << stats.ToString();
646 }
647
648 {
649 HloPassPipeline pipeline("post-fusion optimization");
650 pipeline.AddPass<AllGatherCombiner>(
651 /*combine_threshold_in_bytes=*/1024 * 1024 * 1024,
652 /*combine_threshold_count=*/256);
653 pipeline.AddPass<AllReduceCombiner>(
654 debug_options.xla_gpu_all_reduce_combine_threshold_bytes(),
655 /*combine_threshold_count=*/256);
656 pipeline.AddPass<ReduceScatterCombiner>(
657 /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
658 /*combine_threshold_count=*/256);
659
660 if (debug_options.xla_gpu_all_reduce_contiguous()) {
661 pipeline.AddPass<AllReduceContiguous>();
662 }
663
664 int32_t blueconnect_num_devices_per_host =
665 debug_options.xla_gpu_all_reduce_blueconnect_num_devices_per_host();
666 if (blueconnect_num_devices_per_host > 0) {
667 pipeline.AddPass<AllReduceBlueConnect>(blueconnect_num_devices_per_host);
668 }
669
670 if (debug_options.xla_gpu_enable_async_all_reduce()) {
671 AsyncCollectiveCreator::CollectiveCreatorConfig config;
672 config.convert_all_reduce = [](const HloInstruction*) { return true; };
673 pipeline.AddPass<AsyncCollectiveCreator>(std::move(config));
674 }
675
676 pipeline.AddPass<CollectivesScheduleLinearizer>();
677
678 AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts;
679 options.set_is_layout_sensitive(true);
680 pipeline.AddPass<AlgebraicSimplifier>(options);
681 pipeline.AddPass<OptimizationBarrierExpander>();
682 pipeline.AddPass<BitcastDecomposer>();
683 pipeline.AddPass<TupleSimplifier>();
684
685 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
686 }
687
688 return OkStatus();
689 }
690
691 // Modifies the given HLO module so that it will be accepted by IrEmitter.
692 // Unlike optimization passes, these passes are required for correctness.
693 Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
694 // In some cases, we have to place the result of an instruction in a temporary
695 // buffer. For instance, the buffer that holds an external parameter is
696 // assumed immutable at this point, and should not be reused for output
697 // (b/27180329). Therefore, in that case, we set the output to be a copy of
698 // the parameter.
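// For illustration (sketch): if the entry computation is just
//   ROOT %p = f32[4] parameter(0)
// copy insertion rewrites the root to
//   ROOT %copy = f32[4] copy(%p)
// so the result is written to a buffer distinct from the parameter's buffer.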
699 HloPassPipeline pipeline("GPU-ir-emit-prepare");
700 AddHloVerifier(
701 &pipeline,
702 HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
703 LayoutAssignment::InstructionCanChangeLayout),
704 /*debug_only=*/true);
705
706 // Copy insertion should be performed immediately before IR emission to avoid
707 // inserting unnecessary copies (later pass adds an instruction which
708 // materializes the value) or missing a necessary copy (later pass removes an
709 // instruction which materializes a value). DCE must be run immediately before
710 // (and sometimes after) copy insertion, to keep dead code from interfering
711 // with the rewrites.
712 pipeline.AddPass<HloDCE>();
713 if (hlo_module->config().alias_passthrough_params()) {
714 pipeline.AddPass<AliasPassthroughParams>();
715 }
716 pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
717 pipeline.AddPass<CopyInsertion>(GetCanShareBuffer());
718 pipeline.AddPass<GpuSanitizeConstantNames>();
719 return pipeline.Run(hlo_module).status();
720 }
721
722 Status GpuCompiler::OptimizeHloPostLayoutAssignment(
723 HloModule* hlo_module, se::StreamExecutor* stream_exec,
724 se::DeviceMemoryAllocator* device_allocator) {
725 const DebugOptions& debug_options = hlo_module->config().debug_options();
726
727 {
728 HloPassPipeline pipeline("hlo normalization");
729 pipeline.AddPass<ReshapeDecomposer>();
730 pipeline.AddPass<ReduceDecomposer>([&](const HloInstruction* r) {
731 return IsReductionFromOrToContiguousDimensions(*r);
732 });
733 if (hlo_module->config().debug_options().xla_gpu_normalize_layouts()) {
734 pipeline.AddPass<LayoutNormalization>();
735 }
736 pipeline.AddPass<BroadcastCanonicalizer>();
737 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
738 }
739
740 HloPassPipeline pipeline("post-layout_assignment");
741 AddHloVerifier(&pipeline,
742 HloVerifierOpts{}
743 .MakeLayoutSensitive()
744 .WithInstructionCanChangeLayout(
745 LayoutAssignment::InstructionCanChangeLayout)
746 .VerifyBroadcastDimensionsOrder()
747 .VerifyReshapeIsBitcast(),
748 /*debug_only=*/true);
749
750 pipeline.AddPass<ReductionDegenerateDimRemover>();
751 pipeline.AddPass<ReductionLayoutNormalizer>();
752 pipeline.AddPass<ReductionDimensionGrouper>();
753 pipeline.AddPass<HloPassFix<ReductionSplitter>>();
754 pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(
755 stream_exec->GetDeviceDescription().cuda_compute_capability());
756
757 // The LayoutAssignment pass may leave behind kCopy instructions which are
758 // duplicates or NOPs, so remove them with algebraic simplification and CSE.
759 AlgebraicSimplifierOptions options;
760 options.set_is_layout_sensitive(true);
761 options.set_enable_conv_operand_swap(false);
762 // "slow" minmax means we propagate nan.
763 options.set_minmax_propagate_nan(
764 !hlo_module->config().debug_options().xla_gpu_enable_fast_min_max());
765 pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
766
767 // GemmRewriter assumes that all transposes are folded into gemms, but,
768 // since commit 7d529df, this is not always true at this point.
769 // Therefore, rerun transpose folding.
770 pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot,
771 TransposeFolding::NeverFoldTranspose);
772 // Rewrite GEMMs into custom calls.
773 pipeline.AddPass<GemmRewriter>();
774
775 // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
776 pipeline.AddPass<GemmBroadcastFoldingRewriter>();
777
778 // Run conversion again, to catch those matrix multiplications which were not
779 // rewritten into cuBLAS calls.
780 GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/false,
781 stream_exec);
782 pipeline.AddPass<BFloat16Normalization>(&bf16);
783
784 // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
785 if (debug_options.xla_gpu_simplify_all_fp_conversions())
786 pipeline.AddPass<SimplifyFPConversions>();
787
788 // Choose the fastest algorithm for each conv.
789 //
790 // We pick the algorithm before fusion so we can generate better HLO. After
791 // GpuConvRewriter, our convolutions are CustomCalls which return a
792 // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
793 // scratch:
794 //
795 // customcall = (f32[...], f32[0])
796 // return gte(customcall, 0)
797 //
798 // The algorithm picker then chooses the best algorithm, and potentially
799 // increases the scratch space. It replaces customcall with new_tuple,
800 // giving us the following:
801 //
802 // new_customcall = (f32[...], f32[N])
803 // new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
804 // return gte(new_tuple, 0)
805 //
806 // The new tuple and gte instructions can then be simplified away, because
807 // nobody is expected to use the scratch value.
808 //
809 // However, if we were to run GpuConvAlgorithmPicker after fusion
810 // the gte(customcall, 0) would probably already be absorbed into a fusion node. We
811 // can't simplify across HloComputation boundaries, so in this case we
812 // wouldn't be able to simplify away the new_tuple bits.
813 pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
814
815 // Clean up new_tuple described above.
816 pipeline.AddPass<TupleSimplifier>();
817
818 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
819 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
820
821 return OkStatus();
822 }
823
824 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
825 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
826 const CompileOptions& options) {
827 // We dump the post-optimization HLO in RunBackend so no need to dump it here.
828 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
829 uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
830 tensorflow::profiler::TraceMe activity(
831 [&] { return absl::StrCat("HLO Transforms:", module->name()); },
832 tensorflow::profiler::TraceMeLevel::kInfo);
833 TF_RETURN_IF_ERROR(
834 OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
835
836 TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
837
838 uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
839
840 // This won't record values for calls that error out (because if they error
841 // out we have no way of telling how far through the process we got).
842 RecordHloPassesDuration(end_usecs - start_usecs);
843
844 return std::move(module);
845 }
846
847 static std::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
848 const HloInstruction*,
849 const ShapeIndex&) {
850 return std::nullopt;
851 }
852
853 StatusOr<std::unique_ptr<BufferAssignment>> GpuCompiler::AssignBuffers(
854 const HloModule* hlo_module) {
855 TF_ASSIGN_OR_RETURN(HloSchedule hlo_schedule,
856 ScheduleGpuModule(hlo_module, pointer_size_));
857
858 auto buffer_size_bytes_function =
859 [this](const BufferValue& buffer_value) -> int64_t {
860 return GetSizeOfShape(buffer_value.shape(), pointer_size_);
861 };
862
863 TF_ASSIGN_OR_RETURN(
864 std::unique_ptr<BufferAssignment> assignment,
865 BufferAssigner::Run(
866 hlo_module, std::make_unique<SequentialHloOrdering>(hlo_schedule),
867 buffer_size_bytes_function,
868 /*color_alignment=*/
869 [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
870 /*allocate_buffers_for_constants=*/true,
871 /*colorer=*/BufferAssigner::DefaultColorer(),
872 /*must_not_live_out=*/{}, GetCanShareBuffer()));
873
874 return std::move(assignment);
875 }
876
877 #if XLA_ENABLE_XLIR
878 static StatusOr<OwnedJitRtProgram> LowerToJitRt(
879 mlir::ModuleOp mlir_module, llvm::StringRef entry_function_name,
880 llvm::ArrayRef<int64_t> buffer_sizes, HloModule* hlo_module,
881 se::StreamExecutor* stream_exec) {
882 // Forward collective (NCCL) attributes for use by the lowering pipeline.
883 mlir::OpBuilder builder(mlir_module.getContext());
884 mlir::IntegerAttr replica_count_attr =
885 builder.getI64IntegerAttr(hlo_module->config().replica_count());
886 mlir::IntegerAttr num_partitions_attr =
887 builder.getI64IntegerAttr(hlo_module->config().num_partitions());
888 mlir::func::FuncOp func =
889 mlir_module.lookupSymbol<mlir::func::FuncOp>(entry_function_name);
890 func->setAttr("replica_count", replica_count_attr);
891 func->setAttr("num_partitions", num_partitions_attr);
892
893 tensorflow::GpuBinaryOptions options;
894 if (stream_exec == nullptr) {
895 options = tensorflow::GpuBinaryOptions::DefaultGpuBinaryOptions();
896 } else {
897 options.platform_name = stream_exec->platform()->Name();
898 options.gpu_device_info = xla::gpu::GetGpuDeviceInfo(stream_exec);
899 options.cuda_compute_capability =
900 stream_exec->GetDeviceDescription().cuda_compute_capability();
901 options.rocm_compute_capability =
902 stream_exec->GetDeviceDescription().rocm_compute_capability();
903 }
904
905 // Lower LMHLO operations to the JitRt compatible custom calls.
906 TF_RETURN_IF_ERROR(tensorflow::ConvertLmhloToJitRt(
907 mlir_module, {entry_function_name.data(), entry_function_name.size()},
908 buffer_sizes, options));
909 // Serialize module to pass it to GpuExecutable for compilation.
910 std::string serialized_module;
911 llvm::raw_string_ostream os(serialized_module);
912 mlir_module.print(os);
913
914 // TODO(b/232033540): Pass MLIR module directly to JitRt to instantiate an
915 // executable, without forcing serialization.
916 return std::make_unique<GpuExecutable::JitRtProgram>(
917 entry_function_name.str(), os.str(), buffer_sizes.vec(),
918 hlo_module->config().debug_options());
919 }
920 #endif // XLA_ENABLE_XLIR
921
922 using OutputInfoMap =
923 absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
924 static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
925 std::vector<BufferAllocation>* allocations,
926 OutputInfoMap* output_info,
927 Shape* output_shape,
928 EntryFunctionAttributes* entry_func_attrs);
929
930 namespace {
931 // Removes all globals from the given module that are both uninitialized and
932 // have no uses within that module.
933 void RemoveUnusedAndUninitializedGlobals(
934 llvm::Module* llvm_module,
935 const std::vector<GpuExecutable::ConstantInfo>& constants) {
936 for (const auto& info : constants) {
937 // Empty content means the constant is initialized in the LLVM IR, so we
938 // must not remove it.
939 if (!info.content.empty()) {
940 llvm::GlobalVariable* global =
941 llvm_module->getGlobalVariable(info.symbol_name);
942 CHECK(global != nullptr);
943 if (global->use_empty()) {
944 global->eraseFromParent();
945 }
946 }
947 }
948 }
949 } // namespace
950
951 struct CompileModuleResults {
952 std::unique_ptr<llvm::Module> llvm_module;
953 std::unique_ptr<BufferAssignment> buffer_assignment;
954 std::vector<BufferAllocation> allocations;
955 std::variant<OwnedThunkSequence, OwnedJitRtProgram> executable;
956 EntryFunctionAttributes entry_func_attrs;
957 std::vector<GpuExecutable::ConstantInfo> constants;
958 OutputInfoMap output_info;
959 Shape output_shape;
960 std::string module_name;
961 };
962
963 // The order of `thunk_sequence` corresponds to
964 // `hlo_schedule->ThunkLaunchOrder()`.
965 static Status CompileModuleToLlvmIrImpl(
966 HloModule* hlo_module, llvm::LLVMContext* llvm_context,
967 const std::string& target_triple, const std::string& data_layout,
968 const std::string& platform_name, const se::Platform::Id platform_id,
969 GpuDeviceInfo gpu_device_info,
970 se::CudaComputeCapability cuda_compute_capability,
971 se::RocmComputeCapability rocm_compute_capability,
972 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
973 int pointer_size, CompileModuleResults* results,
974 se::StreamExecutor* stream_exec = nullptr) {
975 results->llvm_module = std::make_unique<llvm::Module>("", *llvm_context);
976 results->llvm_module->setTargetTriple(target_triple);
977 results->llvm_module->setDataLayout(data_layout);
978
979 TF_ASSIGN_OR_RETURN(HloSchedule hlo_schedule,
980 ScheduleGpuModule(hlo_module, pointer_size));
981
982 auto buffer_size_bytes_function =
983 [pointer_size](const BufferValue& buffer_value) -> int64_t {
984 return GetSizeOfShape(buffer_value.shape(), pointer_size);
985 };
986
987 TF_ASSIGN_OR_RETURN(
988 results->buffer_assignment,
989 BufferAssigner::Run(
990 hlo_module, std::make_unique<SequentialHloOrdering>(hlo_schedule),
991 buffer_size_bytes_function,
992 /*color_alignment=*/
993 [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
994 /*allocate_buffers_for_constants=*/true,
995 /*colorer=*/BufferAssigner::DefaultColorer(),
996 /*must_not_live_out=*/{}, can_share_buffer_function));
997
998 VLOG(1) << "Buffer Assignment Stats for " << hlo_module->name() << "\n"
999 << results->buffer_assignment->GetStats().ToString();
1000 DumpHloModuleIfEnabled(*hlo_module, *results->buffer_assignment,
1001 absl::StrCat("sm_", cuda_compute_capability.ToString(),
1002 "_gpu_after_optimizations"));
1003
1004 uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
1005 mlir::DialectRegistry registry;
1006 IrEmitterUnnested::GetDependentDialects(registry);
1007 mlir::MLIRContext mlir_context(registry);
1008 mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
1009 mlir::ModuleOp::create(mlir::Builder(&mlir_context).getUnknownLoc());
1010
1011 TF_RETURN_IF_ERROR(
1012 HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module));
1013
1014 results->module_name = mlir::GetNameFromLoc(mlir_module->getLoc());
1015
1016 if (DumpingEnabledForHloModule(*hlo_module)) {
1017 DumpToFileInDirOrStdout(*hlo_module, "lmhlo", mlir_module.get());
1018 }
1019
1020 auto entry_function = mlir::cast<mlir::func::FuncOp>(
1021 mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
1022
1023 TF_RETURN_IF_ERROR(GetMlirAllocationInfo(
1024 entry_function, &results->allocations, &results->output_info,
1025 &results->output_shape, &results->entry_func_attrs));
1026
1027 if (hlo_module->config().debug_options().xla_gpu_enable_mlir_lowering()) {
1028 mlir::PassManager pm(&mlir_context);
1029 pm.addPass(mlir::createGpuFusionRewritePass());
1030 if (failed(pm.run(mlir_module.get()))) {
1031 return InternalError("Failed to run gpu-fusion-rewrite pass");
1032 }
1033 }
1034
1035 IrEmitterContext ir_emitter_context(
1036 /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr, platform_name,
1037 gpu_device_info, cuda_compute_capability, rocm_compute_capability,
1038 &mlir_context, results->llvm_module.get());
1039
1040 ir_emitter_context.set_allocations(results->allocations);
1041
1042 TF_ASSIGN_OR_RETURN(
1043 auto ir_emitter,
1044 IrEmitterUnnested::Create(hlo_module->config(), &ir_emitter_context));
1045
1046 {
1047 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
1048
1049 TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.getBody()));
1050
1051 bool supports_runtime_managed_constants =
1052 // TODO(b/218907125): Implement this feature for ROCm as well.
1053 platform_id != se::rocm::kROCmPlatformId &&
1054 hlo_module->config().debug_options().xla_gpu_enable_shared_constants();
1055 if (supports_runtime_managed_constants) {
1056 // Remove these globals from the generated code to indicate that XLA is
1057 // responsible for allocating and initializing them.
1058 RemoveUnusedAndUninitializedGlobals(ir_emitter_context.llvm_module(),
1059 ir_emitter_context.constants());
1060 }
1061
1062 results->constants = std::move(ir_emitter_context.constants());
1063 uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
1064
1065 // This won't record values for calls that error out (because if they error
1066 // out we have no way of telling how far through the process we got).
1067 RecordHloToLlvmDuration(end_usecs - start_usecs);
1068 }
1069
1070 #if XLA_ENABLE_XLIR
1071 if (IsJitRtExecutableEnabled(hlo_module->config())) {
1072 std::vector<int64_t> buffer_sizes;
1073 llvm::transform(
1074 results->allocations, std::back_inserter(buffer_sizes),
1075 [](const BufferAllocation& allocation) { return allocation.size(); });
1076 TF_ASSIGN_OR_RETURN(results->executable,
1077 LowerToJitRt(*mlir_module, entry_function.getName(),
1078 buffer_sizes, hlo_module, stream_exec));
1079 return OkStatus();
1080 }
1081 #endif // XLA_ENABLE_XLIR
1082
1083 results->executable = ir_emitter->ConsumeThunkSequence();
1084 return OkStatus();
1085 }
1086
1087 static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
1088 void* context) {
1089 std::string error_string;
1090 llvm::raw_string_ostream string_printer(error_string);
1091 llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer);
1092 diag_info.print(diagnostic_printer);
1093
1094 VLOG(5) << error_string;
1095 }
1096
1097 StatusOr<std::pair<std::string, std::vector<uint8_t>>>
1098 GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
1099 std::unique_ptr<llvm::Module> llvm_module,
1100 se::StreamExecutor* stream_exec,
1101 const CompileOptions& options,
1102 const HloModule* debug_module) {
1103 using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1104
1105 const auto compile_single_module =
1106 [this, stream_exec, &module_config, debug_module](
1107 llvm::Module* llvm_module, bool relocatable,
1108 std::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
1109 {
1110 XLA_SCOPED_LOGGING_TIMER(
1111 "GpuCompiler::RunBackend - Running LLVM verifier");
1112
1113 llvm_module->getContext().setDiagnosticHandlerCallBack(
1114 NullDiagnosticHandler, nullptr);
1115
1116 std::string err;
1117 llvm::raw_string_ostream err_stream(err);
1118
1119 // verifyModule() returns true if the module is broken.
1120 TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
1121 << "Invalid LLVM IR before optimizations:\n"
1122 << err_stream.str()
1123 << "\nThis probably indicates a bug in the HLO -> LLVM IR "
1124 "lowering. Rerun with --xla_dump_to to get the IR"
1125 << (debug_module
1126 ? absl::StrCat(" and looks for files with name containing: *",
1127 FilenameFor(*debug_module, "", ""), "*")
1128 : ".");
1129 }
1130 GpuVersion gpu_version = GetGpuVersion(stream_exec);
1131 StatusOr<std::pair<std::string, std::vector<uint8_t>>> result =
1132 CompileTargetBinary(module_config, llvm_module, gpu_version,
1133 stream_exec, relocatable, debug_module);
1134
1135 if (!result.ok()) {
1136 return result;
1137 }
1138
1139 const bool should_dump =
1140 DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
1141 module_config.debug_options());
1142
1143 if (should_dump) {
1144 if (debug_module) {
1145 if (shard_number.has_value()) {
1146 llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
1147 /*optimized=*/true,
1148 std::to_string(*shard_number));
1149 } else {
1150 llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
1151 /*optimized=*/true);
1152 }
1153 } else {
1154 LOG(ERROR)
1155 << "Dumping is not implemented since the file name cannot be "
1156 "inferred. Please implement (potentially MLIR) module -> "
1157 "filename heuristic.";
1158 }
1159 }
1160
1161 if (user_post_optimization_hook_) {
1162 user_post_optimization_hook_(*llvm_module);
1163 }
1164
1165 // Write PTX to IR dump directory, if IR dumping was requested.
1166 if (should_dump) {
1167 absl::string_view ptx = result->first;
1168 if (debug_module) {
1169 if (shard_number.has_value()) {
1170 DumpToFileInDirOrStdout(*debug_module, "",
1171 std::to_string(*shard_number) + ".ptx", ptx);
1172 } else {
1173 DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
1174 }
1175 } else {
1176 LOG(ERROR)
1177 << "Dumping is not implemented since the file name cannot be "
1178 "inferred. Please implement (potentially MLIR) module -> "
1179 "filename heuristic.";
1180 }
1181 }
1182
1183 return result;
1184 };
1185
1186 tensorflow::thread::ThreadPool* thread_pool;
1187 std::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
1188 switch (
1189 module_config.debug_options().xla_gpu_force_compilation_parallelism()) {
1190 case 0:
1191 thread_pool = options.thread_pool;
1192 break;
1193 case 1:
1194 thread_pool = nullptr;
1195 break;
1196 default:
1197 overriding_thread_pool.emplace(
1198 tensorflow::Env::Default(), "",
1199 module_config.debug_options()
1200 .xla_gpu_force_compilation_parallelism());
1201 thread_pool = &*overriding_thread_pool;
1202 break;
1203 }
1204
1205 if (!thread_pool) {
1206 return compile_single_module(llvm_module.get(), /*relocatable=*/false,
1207 /*shard_number=*/std::nullopt);
1208 }
1209
1210 // Test whether LinkModules is supported.
1211 if (this->LinkModules(stream_exec, {}).status().code() ==
1212 tensorflow::error::Code::UNIMPLEMENTED) {
1213 return compile_single_module(llvm_module.get(), /*relocatable=*/false,
1214 /*shard_number=*/std::nullopt);
1215 }
1216
1217 std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
1218 int num_functions = 0;
1219 for (llvm::Function& func : llvm_module->functions()) {
1220 if (!func.isDeclaration() &&
1221 func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
1222 num_functions++;
1223 }
1224 }
1225
1226 // Record the names of some constant global variables and their initializers.
1227 // We'll change the linkage type of these variables from external to internal
1228 // to ensure constant-folding works properly after calling llvm::SplitModule.
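// Illustrative sketch (hypothetical symbol name): a shard produced by
// llvm::SplitModule may only see the declaration
//   @const_buffer = external constant [4 x float]
// for a global that had an initializer in the original module; restoring the
// initializer and switching it to internal linkage lets each shard
// constant-fold loads from it independently.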
1229 llvm::DenseMap<llvm::StringRef, llvm::Constant*> const_initializer_map;
1230 for (llvm::GlobalVariable& gv : llvm_module->globals()) {
1231 if (gv.hasName() && gv.isConstant() && gv.hasInitializer() &&
1232 gv.hasExternalLinkage()) {
1233 llvm::Constant* initializer = gv.getInitializer();
1234 unsigned int num_elements = 0;
1235 if (auto* caz =
1236 llvm::dyn_cast<llvm::ConstantAggregateZero>(initializer)) {
1237 num_elements = caz->getElementCount().getFixedValue();
1238 } else if (auto* cds = llvm::dyn_cast<llvm::ConstantDataSequential>(
1239 initializer)) {
1240 num_elements = cds->getNumElements();
1241 }
1242 if (num_elements > 0) {
1243 const_initializer_map[gv.getName()] = initializer;
1244 }
1245 }
1246 }
1247
1248 llvm::SplitModule(
1249 *llvm_module,
1250 std::max<unsigned>(
1251 1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
1252 [&](std::unique_ptr<llvm::Module> module) {
1253 // Change the linkage type of some global constant variables to internal
1254 for (llvm::GlobalVariable& gv : module->globals()) {
1255 if (gv.hasName() && gv.isConstant() && !gv.hasInitializer() &&
1256 const_initializer_map.count(gv.getName()) != 0) {
1257 gv.setInitializer(const_initializer_map[gv.getName()]);
1258 gv.setLinkage(llvm::GlobalValue::InternalLinkage);
1259 }
1260 }
1261 llvm_modules.push_back(std::move(module));
1262 },
1263 /*PreserveLocals=*/true);
1264
1265 std::vector<StatusOr<BackendCompileResult>> compile_results(
1266 llvm_modules.size());
1267 tensorflow::BlockingCounter counter(llvm_modules.size());
1268 for (int i = 0; i < llvm_modules.size(); i++) {
1269 thread_pool->Schedule(
1270 [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
1271 llvm::Module* original_module = llvm_modules[i].get();
1272 llvm::LLVMContext context;
1273 std::string buffer;
1274 llvm::raw_string_ostream error(buffer);
1275
1276 std::unique_ptr<llvm::Module> new_llvm_module;
1277 // Switch to a new context by dumping and re-parsing LLVM IR. Each
1278 // thread has its own context to avoid race conditions.
1279 {
1280 std::string ir;
1281 {
1282 llvm::raw_string_ostream os(ir);
1283 original_module->print(os, nullptr);
1284 }
1285 llvm::SMDiagnostic err;
1286 new_llvm_module = llvm::parseAssemblyString(ir, err, context);
1287 if (!new_llvm_module) {
1288 std::string err_string;
1289 llvm::raw_string_ostream os(err_string);
1290 err.print(/*ProgName=*/nullptr, os, /*ShowColors=*/false);
1291 LOG(FATAL) << "Failed to parse IR: " << err_string;
1292 }
1293 }
1294
1295 compile_results[i] = compile_single_module(
1296 new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
1297 counter.DecrementCount();
1298 });
1299 }
1300 counter.Wait();
1301
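// Concatenate the PTX produced for each shard and link the per-shard device
// binaries into a single image.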
1302 std::string ptx_snippets;
1303 std::vector<std::vector<uint8_t>> submodule_compile_results;
1304 for (auto& maybe_result : compile_results) {
1305 TF_ASSIGN_OR_RETURN(auto result, maybe_result);
1306 if (result.second.empty()) {
1307 continue;
1308 }
1309 ptx_snippets += result.first;
1310 ptx_snippets += "\n";
1311 submodule_compile_results.push_back(result.second);
1312 }
1313
1314 auto maybe_backend_result =
1315 this->LinkModules(stream_exec, std::move(submodule_compile_results));
1316 if (!maybe_backend_result.ok()) {
1317 LOG(ERROR) << "The CUDA linking API did not work. Please use "
1318 "XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to "
1319 "bypass it, but expect to get longer compilation time due to "
1320 "the lack of multi-threading.";
1321 return maybe_backend_result.status();
1322 }
1323
1324 return std::make_pair(ptx_snippets, std::move(*maybe_backend_result));
1325 }
1326
1327 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
1328 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
1329 const CompileOptions& options) {
1330 VLOG(1) << "Starting to compile HLO module " << module->name();
1331 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
1332 std::string slow_compilation_msg =
1333 absl::StrCat("Compiling module ", module->name());
1334 auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
1335
1336 TF_RET_CHECK(stream_exec != nullptr);
1337
1338 llvm::LLVMContext llvm_context;
1339
1340 GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
1341
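// Optionally run a cost analysis so we can log how many bytes the module
// reads and writes.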
1342 if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
1343 HloCostAnalysis::Options options{ShapeSizeBytesFunction()};
1344 options.set_bytes_per_second(
1345 stream_exec->GetDeviceDescription().memory_bandwidth());
1346 GpuHloCostAnalysis cost_analysis(options);
1347 TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
1348 VLOG(1) << "HLO memory read+written: "
1349 << tensorflow::strings::HumanReadableNumBytes(
1350 cost_analysis.bytes_accessed());
1351 if (module->config().hlo_profiling_enabled()) {
1352 LOG(ERROR) << "--xla_hlo_profile for GPU is unsupported.";
1353 }
1354 }
1355
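// Lower the HLO module to LLVM IR together with the auxiliary artifacts
// (thunks or a JitRt program, buffer assignment, constants) collected in
// CompileModuleResults.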
1356 CompileModuleResults compile_module_results;
1357 TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1358 module.get(), &llvm_context, target_triple_, data_layout_,
1359 stream_exec->platform()->Name(), stream_exec->platform()->id(),
1360 gpu_device_info,
1361 stream_exec->GetDeviceDescription().cuda_compute_capability(),
1362 stream_exec->GetDeviceDescription().rocm_compute_capability(),
1363 GetCanShareBuffer(), pointer_size_, &compile_module_results,
1364 stream_exec));
1365
1366 if (user_pre_optimization_hook_) {
1367 user_pre_optimization_hook_(*compile_module_results.llvm_module);
1368 }
1369 std::string ir_module_string_before_opt;
1370 const bool embed_ir_in_executable =
1371 module->config().debug_options().xla_embed_ir_in_executable();
1372 if (embed_ir_in_executable) {
1373 ir_module_string_before_opt =
1374 llvm_ir::DumpModuleToString(*compile_module_results.llvm_module);
1375 }
1376
1377 llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
1378 /*optimized=*/false);
1379
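// Compile the LLVM IR down to PTX text and a device binary for the target GPU.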
1380 using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1381 TF_ASSIGN_OR_RETURN(
1382 BackendCompileResult backend_result,
1383 CompileToTargetBinary(module->config(),
1384 std::move(compile_module_results.llvm_module),
1385 stream_exec, options, module.get()));
1386 if (DumpingEnabledForHloModule(*module) &&
1387 std::holds_alternative<OwnedThunkSequence>(
1388 compile_module_results.executable)) {
1389 const ThunkSequence& thunk_sequence =
1390 *std::get<OwnedThunkSequence>(compile_module_results.executable);
1391 DumpToFileInDirOrStdout(*module, "", "thunk_sequence",
1392 thunk_sequence.ToString());
1393 }
1394
1395 auto buffer_assignment_proto = std::make_unique<BufferAssignmentProto>(
1396 compile_module_results.buffer_assignment->ToProto());
1397
1398 // Make it shared to be captured in the following lambda.
1399 std::shared_ptr<const BufferAssignment> buffer_assignment(
1400 std::move(compile_module_results.buffer_assignment));
1401
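// Bundle the backend result with the thunks, constants, and buffer metadata
// produced above into a GpuExecutable.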
1402 GpuVersion gpu_version = GetGpuVersion(stream_exec);
1403 TF_ASSIGN_OR_RETURN(
1404 auto gpu_executable,
1405 GpuExecutable::Create(
1406 {std::move(backend_result.first), std::move(backend_result.second),
1407 gpu_version, std::move(compile_module_results.executable),
1408 compile_module_results.entry_func_attrs,
1409 std::move(compile_module_results.constants),
1410 std::move(compile_module_results.output_info),
1411 compile_module_results.module_name,
1412 compile_module_results.output_shape,
1413 std::move(compile_module_results.allocations),
1414 std::move(buffer_assignment_proto),
1415 [buffer_assignment] { return buffer_assignment->ToVerboseString(); },
1416 std::move(module)}));
1417 if (embed_ir_in_executable) {
1418 DCHECK_NE("", ir_module_string_before_opt);
1419 gpu_executable->set_ir_module_string(ir_module_string_before_opt);
1420 }
1421
1422 // Dump the computation proto and buffer assignment for debugging and
1423 // testing if dumping or embed_ir_in_executable is enabled.
1424 if (embed_ir_in_executable ||
1425 DumpingEnabledForHloModule(gpu_executable->module())) {
1426 auto hlo_proto = std::make_unique<HloProto>();
1427 *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto();
1428 *hlo_proto->mutable_buffer_assignment() = buffer_assignment->ToProto();
1429 gpu_executable->set_hlo_proto(std::move(hlo_proto));
1430 }
1431 gpu_executable->set_debug_info(buffer_assignment->GetStats().ToString());
1432 return static_cast<std::unique_ptr<Executable>>(std::move(gpu_executable));
1433 }
1434
1435 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
1436 GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
1437 const AotCompilationOptions& options) {
1438 #if XLA_ENABLE_XLIR
1439 CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
1440 CHECK(options.executor() != nullptr);
1441 auto stream_exec = options.executor();
1442
1443 std::vector<std::unique_ptr<HloModule>> modules =
1444 module_group->ConsumeModules();
1445 std::vector<std::unique_ptr<AotCompilationResult>> results;
1446
1447 for (const auto& module : modules) {
1448 llvm::LLVMContext llvm_context;
1449 GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
1450
1451 // Compile the module
1452 CompileModuleResults compile_module_results;
1453 TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1454 module.get(), &llvm_context, target_triple_, data_layout_,
1455 stream_exec->platform()->Name(), stream_exec->platform()->id(),
1456 gpu_device_info,
1457 stream_exec->GetDeviceDescription().cuda_compute_capability(),
1458 stream_exec->GetDeviceDescription().rocm_compute_capability(),
1459 GetCanShareBuffer(), pointer_size_, &compile_module_results));
1460 auto& compiled_executable = compile_module_results.executable;
1461
1462 if (!std::holds_alternative<OwnedJitRtProgram>(compiled_executable)) {
1463 return InternalError("JitRtProgram not provided");
1464 }
1465
1466 const auto& program = std::get<OwnedJitRtProgram>(compiled_executable);
1467
1468 // Options for the default JitRt compilation pipeline.
1469 runtime::CompilationPipelineOptions copts;
1470
1471 // Options for constructing JitRt JitExecutable.
1472 runtime::JitExecutable::Options opts;
1473 opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
1474 opts.compiler.register_dialects =
1475 runtime::RegisterDefaultXlaRuntimeDialects;
1476
1477 // Register JitRt Gpu runtime custom calls with the linker.
1478 opts.compiler.symbols_binding = runtime::ToSymbolsBinding(
1479 JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);
1480
1481 opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
1482 runtime::CreateDefaultXlaRuntimeCompilationPipeline(pm, copts);
1483 };
1484
1485 // Instantiate new JitExecutable from the MLIR source.
1486 auto jit_executable = runtime::JitExecutable::Instantiate(
1487 program->module, program->entry_point, opts);
1488 if (auto err = jit_executable.takeError())
1489 return InternalError("Failed to compile JitRt program: %s",
1490 tfrt::StrCat(err));
1491
1492 // For static shapes we can always serialize only the default executable.
1493 runtime::Executable& executable = jit_executable->DefaultExecutable().get();
1494
1495 // Check if JitRt executable saved the compilation result.
1496 std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
1497 if (!obj_file)
1498 return InternalError("JitRt executable didn't save the obj file");
1499
1500 std::string data(obj_file->getBuffer().data(),
1501 obj_file->getBuffer().size());
1502 results.emplace_back(std::make_unique<xla::gpu::JitRtAotCompilationResult>(
1503 module->ToProto(), data, program->module,
1504 compile_module_results.entry_func_attrs));
1505 }
1506 return std::move(results);
1507 #else
1508 return Unimplemented("");
1509 #endif // XLA_ENABLE_XLIR
1510 }
1511
1512 HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
1513 // Capture just the pointer size, not the entire GpuCompiler object.
1514 return [pointer_size = pointer_size_](const Shape& shape) {
1515 return GetSizeOfShape(shape, pointer_size);
1516 };
1517 }
1518
1519 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
1520 HloModule* hlo_module, llvm::LLVMContext* llvm_context,
1521 const std::string& target_triple, const std::string& data_layout,
1522 const std::string& platform_name, const se::Platform::Id platform_id,
1523 GpuDeviceInfo gpu_device_info,
1524 se::CudaComputeCapability cuda_compute_capability,
1525 se::RocmComputeCapability rocm_compute_capability, int pointer_size) {
1526 CompileModuleResults results;
1527 TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1528 hlo_module, llvm_context, target_triple, data_layout, platform_name,
1529 platform_id, gpu_device_info, cuda_compute_capability,
1530 rocm_compute_capability, DummyCanShareBufferFunction, pointer_size,
1531 &results));
1532 return std::move(results.llvm_module);
1533 }
1534
1535 // Analyze the function signature to reconstruct a vector of BufferAllocation
1536 // objects, as well as other output information.
1537 //
1538 // This function also serves as a half-baked verifier for function arg
1539 // attributes, since a full verifier doesn't exist yet.
1540 static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
1541 std::vector<BufferAllocation>* allocations,
1542 OutputInfoMap* output_info,
1543 Shape* output_shape,
1544 EntryFunctionAttributes* entry_func_attrs) {
1545 CHECK(allocations->empty());
1546 allocations->reserve(func.getNumArguments());
1547
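// Derive a flat byte size for every entry-function argument from its shaped
// type; these sizes seed the BufferAllocation reconstruction below.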
1548 std::vector<int64_t> buffer_sizes;
1549 for (int i = 0; i < func.getNumArguments(); i++) {
1550 mlir::BlockArgument arg = func.getArgument(i);
1551
1552 TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
1553 mlir::ShapedType type = arg.getType().cast<mlir::ShapedType>();
1554 TF_ASSIGN_OR_RETURN(auto element_type_bytes,
1555 GetElementTypeBytes(type.getElementType()));
1556 size_t size = type.getNumElements() * element_type_bytes;
1557 buffer_sizes.push_back(size);
1558 }
1559
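// Reject any argument attribute other than the known lmhlo.* ones.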
1560 for (int i = 0; i < func.getNumArguments(); i++) {
1561 for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
1562 TF_RET_CHECK(attr.getName() == "lmhlo.params" ||
1563 attr.getName() == "lmhlo.param_shape_index" ||
1564 attr.getName() == "lmhlo.constant_name" ||
1565 attr.getName() == "lmhlo.must_alias" ||
1566 attr.getName() == "lmhlo.output_index");
1567 }
1568 }
1569
1570 // Encode buffer parameter metadata in a proto for persistence, because BEF
1571 // doesn't persist function attributes.
1572 for (int i = 0; i < func.getNumArguments(); i++) {
1573 auto buffer = entry_func_attrs->add_buffers();
1574 if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1575 buffer->set_lmhlo_params_present(true);
1576 buffer->set_lmhlo_params(param_attr.cast<mlir::IntegerAttr>().getInt());
1577 }
1578 if (auto shape_index_attr = func.getArgAttr(i, "lmhlo.param_shape_index")) {
1579 auto param_shape_index = buffer->mutable_lmhlo_param_shape_index();
1580 for (const llvm::APInt& element :
1581 shape_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1582 param_shape_index->add_indices(element.getSExtValue());
1583 }
1584 }
1585 if (auto constant_name_attr = func.getArgAttr(i, "lmhlo.constant_name")) {
1586 buffer->set_lmhlo_constant_name(
1587 constant_name_attr.cast<mlir::StringAttr>().str());
1588 }
1589 if (func.getArgAttr(i, "lmhlo.must_alias")) {
1590 buffer->set_lmhlo_must_alias(true);
1591 }
1592 if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
1593 auto output_index = buffer->mutable_lmhlo_output_index();
1594 for (const llvm::APInt& element :
1595 output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1596 output_index->add_indices(element.getSExtValue());
1597 }
1598 }
1599 }
1600 entry_func_attrs->set_result_xla_shape(
1601 func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
1602 .getValue()
1603 .str());
1604
1605 return GpuExecutable::SetUpMlirAllocation(func, buffer_sizes, allocations,
1606 output_info, output_shape);
1607 }
1608
1609 StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
1610 GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
1611 const HloModuleConfig& module_config,
1612 const Compiler::CompileOptions& options,
1613 absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
1614 std::unique_ptr<llvm::Module> llvm_module,
1615 IrEmitterContext* ir_emitter_context) {
1616 mlir::func::FuncOp entry_function =
1617 mlir::cast<mlir::func::FuncOp>(module.lookupSymbol(llvm::StringRef(
1618 entry_function_name.data(), entry_function_name.size())));
1619
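// Recover buffer allocations, output metadata, and entry-function attributes
// from the argument attributes of the LMHLO entry function.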
1620 std::vector<BufferAllocation> allocations;
1621 OutputInfoMap output_info;
1622 Shape output_shape;
1623 EntryFunctionAttributes entry_func_attrs;
1624 TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
1625 &output_info, &output_shape,
1626 &entry_func_attrs));
1627
1628 TF_RET_CHECK(!allocations.empty());
1629
1630 ir_emitter_context->set_allocations(allocations);
1631
1632 TF_ASSIGN_OR_RETURN(auto ir_emitter, IrEmitterUnnested::Create(
1633 module_config, ir_emitter_context));
1634 TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.getBody()));
1635
1636 bool supports_runtime_managed_constants =
1637 // TODO(b/218907125): Implement this feature for ROCm as well.
1638 compiler->PlatformId() != se::rocm::kROCmPlatformId;
1639 if (supports_runtime_managed_constants) {
1640 // Remove these globals from the generated code to indicate that XLA is
1641 // responsible for allocating and initializing them.
1642 RemoveUnusedAndUninitializedGlobals(ir_emitter_context->llvm_module(),
1643 ir_emitter_context->constants());
1644 }
1645
1646 auto thunk_sequence = ir_emitter->ConsumeThunkSequence();
1647
1648 using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
1649 TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
1650 compiler->CompileToTargetBinary(
1651 module_config, std::move(llvm_module), stream_exec,
1652 options, /*debug_module=*/nullptr));
1653
1654 GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
1655 return GpuExecutable::Create(
1656 {std::move(backend_result.first), std::move(backend_result.second),
1657 gpu_version, std::move(thunk_sequence), entry_func_attrs,
1658 std::move(ir_emitter_context->constants()), std::move(output_info),
1659 module_name, output_shape, std::move(allocations)});
1660 }
1661
1662 } // namespace gpu
1663 } // namespace xla
1664