/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace xla {
namespace gpu {

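// An ahead-of-time compilation result that wraps a serialized
// JitRtExecutableProto: the HLO module, the compiled object file, the MLIR
// module, and the entry-function attributes needed to load it back.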
// TODO(b/232263665): This class should be shared between GPU and CPU.
class JitRtAotCompilationResult : public AotCompilationResult {
 public:
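  // Deserializes a result previously produced by SerializeAsString().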
  static StatusOr<std::unique_ptr<JitRtAotCompilationResult>> FromString(
      const std::string& serialized) {
    JitRtExecutableProto jitrt_executable;
    if (!jitrt_executable.ParseFromString(serialized)) {
      return InternalError("Failed to parse serialized JitRtExecutableProto.");
    }
    return std::unique_ptr<JitRtAotCompilationResult>(
        new JitRtAotCompilationResult(std::move(jitrt_executable)));
  }

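  // Builds the proto in place from the compiled artifacts: the HLO module,
  // the object file, the serialized MLIR module, and the entry-function
  // attributes.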
  JitRtAotCompilationResult(HloModuleProto hlo, const std::string& obj_file,
                            const std::string& mlir_module,
                            EntryFunctionAttributes entry_func_attrs) {
    *jitrt_executable_.mutable_hlo_module_proto() = hlo;
    *jitrt_executable_.mutable_entry_func_attrs() = entry_func_attrs;
    jitrt_executable_.set_obj_file(obj_file);
    jitrt_executable_.set_mlir_module(mlir_module);
  }

  StatusOr<std::string> SerializeAsString() const override {
    return jitrt_executable_.SerializeAsString();
  }

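  // Reconstructs a runnable Executable from the serialized artifacts for the
  // given compiler and stream executor.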
  StatusOr<std::unique_ptr<Executable>> LoadExecutable(
      Compiler* compiler, se::StreamExecutor* executor) const override;

 private:
  explicit JitRtAotCompilationResult(JitRtExecutableProto jitrt_executable)
      : jitrt_executable_(std::move(jitrt_executable)) {}

  JitRtExecutableProto jitrt_executable_;
};

// The GPU compiler generates efficient GPU executables. Backend-specific
// subclasses supply the target details via the pure virtual methods below.
class GpuCompiler : public LLVMCompiler {
 public:
  GpuCompiler(se::Platform::Id platform_id, const char* target_triple,
              const char* data_layout);
  ~GpuCompiler() override {}

  using LLVMCompiler::Compile;

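  // Runs HLO-level optimization passes on `module` before code generation.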
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

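  // Computes a buffer assignment for an already-optimized `hlo_module`.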
  StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers(
      const HloModule* hlo_module) override;

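  // Returns the GPU architecture version (e.g. CUDA compute capability) of
  // the device behind `stream_exec`.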
  virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;

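  // Compiles an optimized `module` down to a GPU executable.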
  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

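  // Compiles `module_group` ahead of time into serializable results.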
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     AotCompilationOptions const& options) override;

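  // Lowers `llvm_module` to target assembly text and a device binary
  // (e.g. PTX and cubin on CUDA).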
  StatusOr<std::pair<std::string, std::vector<uint8_t>>> CompileToTargetBinary(
      const HloModuleConfig& module_config,
      std::unique_ptr<llvm::Module> llvm_module,
      se::StreamExecutor* stream_exec, const CompileOptions& options,
      const HloModule* debug_module);

  se::Platform::Id PlatformId() const override { return platform_id_; }

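  // Returns a function that computes the on-device size in bytes of a shape;
  // pointers are sized by pointer_size_.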
  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;

  // Returns a (deserialized) AotCompilationResult from a serialized
  // AotCompilationResult.
  StatusOr<std::unique_ptr<AotCompilationResult>> LoadAotCompilationResult(
      const std::string& serialized_aot_result) override {
    return JitRtAotCompilationResult::FromString(serialized_aot_result);
  }

 protected:
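  // Runs HLO passes that must happen after layouts have been assigned.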
  virtual Status OptimizeHloPostLayoutAssignment(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator);

 private:
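  // Runs the full HLO optimization pipeline on `hlo_module`.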
  Status OptimizeHloModule(HloModule* hlo_module,
                           se::StreamExecutor* stream_exec,
                           se::DeviceMemoryAllocator* device_allocator);

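  // Canonicalizes convolutions into forms the target backend can lower.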
  virtual Status OptimizeHloConvolutionCanonicalization(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) = 0;

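  // Returns a backend-specific predicate deciding whether an instruction may
  // share a buffer with its operand; the default never makes a decision
  // (returns std::nullopt).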
  virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() {
    return
        [](const HloInstruction*, const HloInstruction*,
           const ShapeIndex&) -> std::optional<bool> { return std::nullopt; };
  }

  // TODO(timshen): Replace `debug_module` with some portable debug information
  // that accommodates both HLO and MLIR.
  virtual StatusOr<std::pair<std::string, std::vector<uint8_t>>>
  CompileTargetBinary(const HloModuleConfig& module_config,
                      llvm::Module* llvm_module, GpuVersion gpu_version,
                      se::StreamExecutor* stream_exec, bool relocatable,
                      const HloModule* debug_module) = 0;

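  // Runs the final passes required before IR emission.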
  Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);

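  // Links separately compiled device modules into a single binary; backends
  // that support relocatable compilation override this.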
  virtual StatusOr<std::vector<uint8_t>> LinkModules(
      se::StreamExecutor* stream_exec,
      std::vector<std::vector<uint8_t>> modules) {
    return Unimplemented("LinkModules is not implemented.");
  }

  se::Platform::Id platform_id_;

  // The triple that represents our target.
  const char* target_triple_;

  // The data layout of the emitted module.
  const char* data_layout_;

  // The size in bytes of a pointer. Used by ShapeSizeBytesFunction.
  const int64_t pointer_size_;

  GpuCompiler(const GpuCompiler&) = delete;
  GpuCompiler& operator=(const GpuCompiler&) = delete;
};

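// Typical JIT compilation flow through a concrete GpuCompiler (a sketch;
// `compiler`, `module`, `executor`, and `options` are assumed to be supplied
// by the caller):
//
//   TF_ASSIGN_OR_RETURN(
//       module, compiler->RunHloPasses(std::move(module), executor, options));
//   TF_ASSIGN_OR_RETURN(
//       auto executable,
//       compiler->RunBackend(std::move(module), executor, options));
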
// Compile `hlo_module` using XLA GPU and return the LLVM module thus
// generated. The GpuExecutable (and the Thunks that are part of it) are not
// returned.
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
    const std::string& target_triple, const std::string& data_layout,
    const std::string& platform_name, const se::Platform::Id platform_id,
    GpuDeviceInfo gpu_device_info,
    se::CudaComputeCapability cuda_compute_capability,
    se::RocmComputeCapability rocm_compute_capability, int pointer_size);

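// A minimal usage sketch for a CUDA target (illustrative only: `hlo_module`,
// `data_layout`, `platform_id`, and `gpu_device_info` are assumed to be in
// scope, and the triple, platform name, and capability values are examples):
//
//   llvm::LLVMContext context;
//   StatusOr<std::unique_ptr<llvm::Module>> llvm_module =
//       CompileModuleToLlvmIr(hlo_module, &context, "nvptx64-nvidia-cuda",
//                             data_layout, "CUDA", platform_id,
//                             gpu_device_info,
//                             se::CudaComputeCapability{8, 0},
//                             se::RocmComputeCapability{"gfx908"},
//                             /*pointer_size=*/8);
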
// Compiles the given LMHLO module to an executable.
// `ir_emitter_context` should be partially populated: neither
// `buffer_assignment` nor `buffer_allocations` should be populated yet, while
// the other fields should be populated (or left empty if that field is
// optional).
//
// NOTE: buffer_assignment will be gone from ir_emitter_context once the LMHLO
// transition is done.
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
    GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
    const HloModuleConfig& module_config,
    const Compiler::CompileOptions& options,
    absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
    std::unique_ptr<llvm::Module> llvm_module,
    IrEmitterContext* ir_emitter_context);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_