/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace xla {
namespace gpu {

// TODO(b/232263665): It should be shared between GPU and CPU.
class JitRtAotCompilationResult : public AotCompilationResult {
 public:
  static StatusOr<std::unique_ptr<JitRtAotCompilationResult>> FromString(
      const std::string& serialized) {
    JitRtExecutableProto jitrt_executable;
    if (!jitrt_executable.ParseFromString(serialized)) {
      return InternalError("Failed to parse serialized JitRtExecutableProto.");
    }
    return std::unique_ptr<JitRtAotCompilationResult>(
        new JitRtAotCompilationResult(std::move(jitrt_executable)));
  }

  JitRtAotCompilationResult(HloModuleProto hlo, const std::string& obj_file,
                            const std::string& mlir_module,
                            EntryFunctionAttributes entry_func_attrs) {
    *jitrt_executable_.mutable_hlo_module_proto() = hlo;
    *jitrt_executable_.mutable_entry_func_attrs() = entry_func_attrs;
    jitrt_executable_.set_obj_file(obj_file);
    jitrt_executable_.set_mlir_module(mlir_module);
  }

  StatusOr<std::string> SerializeAsString() const override {
    return jitrt_executable_.SerializeAsString();
  }

  StatusOr<std::unique_ptr<Executable>> LoadExecutable(
      Compiler* compiler, se::StreamExecutor* executor) const override;

 private:
  explicit JitRtAotCompilationResult(JitRtExecutableProto jitrt_executable)
      : jitrt_executable_(std::move(jitrt_executable)) {}

  JitRtExecutableProto jitrt_executable_;
};
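// Example round trip through the serialized form (an illustrative sketch, not
// part of this header's API; assumes `result` is an existing
// JitRtAotCompilationResult and `compiler`/`executor` are valid pointers
// obtained elsewhere):
//
//   TF_ASSIGN_OR_RETURN(std::string serialized, result->SerializeAsString());
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<JitRtAotCompilationResult> restored,
//                       JitRtAotCompilationResult::FromString(serialized));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
//                       restored->LoadExecutable(compiler, executor));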
// The GPU compiler generates efficient GPU executables.
class GpuCompiler : public LLVMCompiler {
 public:
  GpuCompiler(se::Platform::Id platform_id, const char* target_triple,
              const char* data_layout);
  ~GpuCompiler() override {}

  using LLVMCompiler::Compile;

  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers(
      const HloModule* hlo_module) override;

  virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     AotCompilationOptions const& options) override;

  StatusOr<std::pair<std::string, std::vector<uint8_t>>> CompileToTargetBinary(
      const HloModuleConfig& module_config,
      std::unique_ptr<llvm::Module> llvm_module,
      se::StreamExecutor* stream_exec, const CompileOptions& options,
      const HloModule* debug_module);

  se::Platform::Id PlatformId() const override { return platform_id_; }

  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;

  // Returns a (deserialized) AotCompilationResult from a serialized
  // AotCompilationResult.
  StatusOr<std::unique_ptr<AotCompilationResult>> LoadAotCompilationResult(
      const std::string& serialized_aot_result) override {
    return JitRtAotCompilationResult::FromString(serialized_aot_result);
  }

 protected:
  virtual Status OptimizeHloPostLayoutAssignment(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator);

 private:
  Status OptimizeHloModule(HloModule* hlo_module,
                           se::StreamExecutor* stream_exec,
                           se::DeviceMemoryAllocator* device_allocator);

  virtual Status OptimizeHloConvolutionCanonicalization(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) = 0;

  virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() {
    return
        [](const HloInstruction*, const HloInstruction*,
           const ShapeIndex&) -> std::optional<bool> { return std::nullopt; };
  }

  // TODO(timshen): Replace `debug_module` with some portable debug information
  // that accommodates both HLO and MLIR.
  virtual StatusOr<std::pair<std::string, std::vector<uint8_t>>>
  CompileTargetBinary(const HloModuleConfig& module_config,
                      llvm::Module* llvm_module, GpuVersion gpu_version,
                      se::StreamExecutor* stream_exec, bool relocatable,
                      const HloModule* debug_module) = 0;

  Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);

  virtual StatusOr<std::vector<uint8_t>> LinkModules(
      se::StreamExecutor* stream_exec,
      std::vector<std::vector<uint8_t>> modules) {
    return Unimplemented("LinkModules is not implemented.");
  }

  se::Platform::Id platform_id_;

  // The triple that represents our target.
  const char* target_triple_;

  // The data layout of the emitted module.
  const char* data_layout_;

  // The size in bytes of a pointer. Used by ShapeSizeBytesFunction.
  const int64_t pointer_size_;

  GpuCompiler(const GpuCompiler&) = delete;
  GpuCompiler& operator=(const GpuCompiler&) = delete;
};
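// Typical JIT compilation flow (an illustrative sketch, not part of this
// header's API; assumes `compiler` points to a concrete GpuCompiler subclass
// and `stream_exec` to a valid se::StreamExecutor for the target device):
//
//   Compiler::CompileOptions options;
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloModule> optimized_module,
//       compiler->RunHloPasses(std::move(hlo_module), stream_exec, options));
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<Executable> executable,
//       compiler->RunBackend(std::move(optimized_module), stream_exec,
//                            options));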
// Compile `hlo_module` using XLA GPU and return the LLVM module thus
// generated. The GpuExecutable (and the Thunks that are part of it) are not
// returned.
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
    const std::string& target_triple, const std::string& data_layout,
    const std::string& platform_name, const se::Platform::Id platform_id,
    GpuDeviceInfo gpu_device_info,
    se::CudaComputeCapability cuda_compute_capability,
    se::RocmComputeCapability rocm_compute_capability, int pointer_size);

// Compiles the given LMHLO module to an executable.
// `ir_emitter_context` should be partially populated: `buffer_assignment` and
// `buffer_allocations` must not be populated, while all other fields should be
// populated (or left empty if the field is optional).
//
// NOTE: buffer_assignment will be gone from ir_emitter_context once the LMHLO
// transition is done.
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
    GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
    const HloModuleConfig& module_config,
    const Compiler::CompileOptions& options,
    absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
    std::unique_ptr<llvm::Module> llvm_module,
    IrEmitterContext* ir_emitter_context);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_