xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <utility>
23 
24 #include "llvm/IR/PassTimingInfo.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "mlir/ExecutionEngine/OptUtils.h"  // from @llvm-project
28 #include "mlir/Parser/Parser.h"  // from @llvm-project
29 #include "mlir/Pass/PassManager.h"  // from @llvm-project
30 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
31 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h"
32 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
33 #include "tfrt/support/error_util.h"  // from @tf_runtime
34 
35 namespace xla {
36 namespace runtime {
37 
38 using namespace mlir;  // NOLINT
39 
40 using llvm::Error;
41 using tfrt::MakeStringError;
42 
// True iff the build opted into verbose JIT-compiler debugging via the
// DEBUG_XLA_RUNTIME_COMPILER preprocessor definition.
static bool DebugJitCompiler() {
#if defined(DEBUG_XLA_RUNTIME_COMPILER)
  return true;
#else
  return false;
#endif
}
49 
EnablePassTiming()50 static bool EnablePassTiming() {
51 #if defined(ENABLE_XLAR_RUNTIME_PASS_TIMING)
52   return true;
53 #endif
54   return DebugJitCompiler();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Setup MLIR pass pipeline to lower to LLVM dialect, and use ORC JIT to codegen
59 // functions at runtime.
60 //===----------------------------------------------------------------------===//
61 
InitializeCompiler()62 static void InitializeCompiler() {
63   static const bool initialized = ([] {
64     llvm::InitializeNativeTarget();
65     // Initialize asm printer and parser so that we can handle the inline
66     // assembly generated in MLIR for some operations.
67     llvm::InitializeNativeTargetAsmPrinter();
68     llvm::InitializeNativeTargetAsmParser();
69     return true;
70   })();
71   (void)initialized;
72 }
73 
SetupPassDebugging(MLIRContext * context,PassManager & pm)74 static void SetupPassDebugging(MLIRContext* context, PassManager& pm) {
75   // Print IR after all passes.
76   if (DebugJitCompiler()) {
77     context->disableMultithreading();
78     pm.enableIRPrinting([](Pass*, Operation*) { return false; },
79                         [](Pass*, Operation*) { return true; },
80                         /*printModuleScope=*/true,
81                         /*printAfterOnlyOnChange=*/false,
82                         /*printAfterOnlyOnFailure=*/false, llvm::errs());
83   }
84 }
85 
RunPipeline(ModuleOp module,const std::function<void (PassManager &)> & create_pipeline)86 static LogicalResult RunPipeline(
87     ModuleOp module, const std::function<void(PassManager&)>& create_pipeline) {
88   if (!create_pipeline) return success();
89 
90   PassManager pm(module.getContext());
91   SetupPassDebugging(module.getContext(), pm);
92 
93   // Instrument the pass manager to capture timing information.
94   DefaultTimingManager tm;
95   TimingScope timing;
96   if (EnablePassTiming()) {
97     tm.setEnabled(true);
98     timing = tm.getRootScope();
99     pm.enableTiming(timing);
100   }
101 
102   create_pipeline(pm);
103 
104   return pm.run(module);
105 }
106 
107 // Runs the user-provided compilation pipeline to compile the module to LLVM.
RunCompilationPipeline(ModuleOp module,const JitCompiler::Options & opts)108 static LogicalResult RunCompilationPipeline(ModuleOp module,
109                                             const JitCompiler::Options& opts) {
110   return RunPipeline(module, opts.create_compilation_pipeline);
111 }
112 
113 // Runs the user-provided specialization pipeline.
RunSpecializationPipeline(ModuleOp module,const JitCompiler::Options & opts)114 static LogicalResult RunSpecializationPipeline(
115     ModuleOp module, const JitCompiler::Options& opts) {
116   return RunPipeline(module, opts.create_specialization_pipeline);
117 }
118 
119 //===----------------------------------------------------------------------===//
120 
121 // Creates a new MLIR Context and registers all the dialects that are expected
122 // in the compiled module.
CreateMlirContext(const JitCompiler::Options & opts)123 static std::unique_ptr<MLIRContext> CreateMlirContext(
124     const JitCompiler::Options& opts) {
125   DialectRegistry registry;
126 
127   // Call user-provided callback to register all required dialects.
128   if (opts.register_dialects) opts.register_dialects(registry);
129 
130   auto threading = MLIRContext::Threading::DISABLED;
131   auto ctx = std::make_unique<MLIRContext>(registry, threading);
132   ctx->loadAllAvailableDialects();
133   return ctx;
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // JitCompiler implementation.
138 //===----------------------------------------------------------------------===//
139 
// Parses `mlir_module` into an MLIR module and resolves `entrypoint` in it.
// On failure `module_` and/or `entrypoint_` are left null; Instantiate()
// checks them and converts the captured diagnostics into an error.
JitCompiler::JitCompiler(JitCompiler::Options opts,
                         std::string_view mlir_module,
                         std::string_view entrypoint)
    : opts_(std::move(opts)),
      // NOTE: `opts_` must be initialized before `context_`, which reads the
      // dialect-registration callback from the moved-into options.
      context_(CreateMlirContext(opts_)),
      diagnostic_os_(diagnostic_),
      handler_(source_mgr_, context_.get(), diagnostic_os_),
      specialized_(false) {
  // Register the textual input under the synthetic buffer name "xla.program"
  // so parser diagnostics point into the user-provided module.
  source_mgr_.AddNewSourceBuffer(
      llvm::MemoryBuffer::getMemBuffer(mlir_module, "xla.program"),
      llvm::SMLoc());

  // Entrypoint lookup is only attempted when parsing succeeded.
  module_ = parseSourceFile<ModuleOp>(source_mgr_, context_.get());
  if (module_) entrypoint_ = module_->lookupSymbol<func::FuncOp>(entrypoint);
}
155 
156 /*static*/ llvm::Expected<std::unique_ptr<JitCompiler>>
Instantiate(JitCompiler::Options opts,std::string_view mlir_module,std::string_view entrypoint)157 JitCompiler::Instantiate(JitCompiler::Options opts,
158                          std::string_view mlir_module,
159                          std::string_view entrypoint) {
160   std::unique_ptr<JitCompiler> context(
161       new JitCompiler(std::move(opts), mlir_module, entrypoint));
162   if (!context->module_)
163     return context->Error("failed to parse the mlir source");
164   if (!context->entrypoint_)
165     return context->Error("failed to resolve entrypoint function");
166 
167   InitializeCompiler();
168 
169   return {std::move(context)};
170 }
171 
// Compiles the compiler's MLIR module down to a native `Executable`: converts
// the entrypoint signature, runs the user compilation pipeline, translates to
// LLVM IR, and JIT-compiles it. Consumes `compiler`. `memory_region_name`
// tags the mmap-ed code regions; `specialization` (if set) is recorded in the
// returned Executable.
/*static*/ llvm::Expected<Executable> JitCompiler::Compile(
    std::unique_ptr<JitCompiler> compiler, std::string_view memory_region_name,
    llvm::Optional<size_t> specialization) {
  const JitCompiler::Options& opts = compiler->options();
  func::FuncOp entry_func = compiler->entrypoint();
  std::string entrypoint = entry_func.getName().str();

  // We track end-to-end time to compile the final executable.
  auto compilation_start = std::chrono::steady_clock::now();

  // Get the signature of the entrypoint function.
  auto signature = opts.type_converter.Convert(entry_func.getFunctionType());
  if (auto err = signature.takeError()) return std::move(err);

  // Get the calling convention for the entrypoint function.
  if (!opts.calling_convention)
    return compiler->Error("calling convention is not defined");

  // Calling convention conversion can fail if some types are not supported.
  auto runtime_type = opts.calling_convention(entry_func.getFunctionType());
  if (!runtime_type)
    return compiler->Error(
        "calling convention failed to convert entrypoint type");

  // Get the runtime signature of the entrypoint function.
  auto runtime_signature = opts.type_converter.Convert(runtime_type);
  if (auto err = runtime_signature.takeError()) return std::move(err);

  // Get the memory layout for passing function arguments.
  auto arguments_memory_layout =
      Executable::GetArgumentsMemoryLayout(*runtime_signature);
  if (auto err = arguments_memory_layout.takeError()) return std::move(err);

  // Get the memory layout for returning function results.
  auto results_memory_layout =
      Executable::GetResultsMemoryLayout(*runtime_signature);
  if (auto err = results_memory_layout.takeError()) return std::move(err);

  // Mark entry function with an attribute, so it can be converted to an Xla
  // entrypoint (see `rt-convert-to-entrypoint` pass).
  auto unit_attr = UnitAttr::get(entry_func.getContext());
  entry_func->setAttr(kEntrypointAttrName, unit_attr);

  // Run the compilation pipeline to lower the module to LLVM dialect.
  if (failed(RunCompilationPipeline(compiler->module(), opts)))
    return compiler->Error("failed to run compilation pipeline");

  // Opt into LLVM's legacy global pass-timing collection; the report is
  // emitted (and counters reset) after codegen below.
  if (EnablePassTiming()) llvm::TimePassesIsEnabled = true;

  // Prepare JIT target machine for code generation.
  auto builder = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!builder) return builder.takeError();

  auto target_machine = builder->createTargetMachine();
  if (!target_machine) return target_machine.takeError();

  // Name of the compiled module if available.
  auto module_name = compiler->module().getSymName().value_or("<unknown>");

  // Memory region name to mmap executable code.
  std::string mapper_name = llvm::formatv(
      "/xla{0}{1}:@{2}::@{3}:{4}", memory_region_name.empty() ? "" : ":",
      EscapeMemRegionName(memory_region_name), module_name, entrypoint,
      specialization.has_value() ? "specialized" : "default");

  // Custom memory mapper to tag memory allocated for XLA executables.
  std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper =
      XlaRuntimeMemoryMapper::Create(std::move(mapper_name));

  // Register symbols required for running XLA Executable.
  ExecutionEngine::SymbolsBinding symbols =
      RuntimeSymbolsBinding(compiler->options().symbols_binding);

  // Construct options for the XLA runtime execution engine.
  ExecutionEngine::JitOptions engine_options;
  engine_options.opt_level = compiler->options().jit_code_opt_level;
  engine_options.target_machine = target_machine->get();
  engine_options.make_optimizing_transformer = makeOptimizingTransformer;
  // NOTE(review): engine_options holds raw pointers into `target_machine`
  // and `memory_mapper`; both must outlive engine creation below.
  engine_options.section_memory_mapper = memory_mapper.get();
  engine_options.symbols_binding = std::move(symbols);

  // Translate MLIR module to the LLVM module.
  auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
  auto llvm_module = translateModuleToLLVMIR(compiler->module(), *llvm_ctx);
  if (!llvm_module)
    return MakeStringError("failed to translate module to LLVM IR");

  // Compile input module to the native function.
  auto engine = ExecutionEngine::CreateFromModule(
      std::move(llvm_ctx), std::move(llvm_module), entrypoint, engine_options);
  if (auto err = engine.takeError()) return std::move(err);

  // At this point compilation is completed, and all symbols in the LLVM module
  // materialized as addresses (entrypoint is an executable function pointer).
  auto time_to_compile = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::steady_clock::now() - compilation_start);

  if (EnablePassTiming()) llvm::reportAndResetTimings();

  // Ownership of the memory mapper transfers into the Executable along with
  // the JIT engine and the converted signatures/layouts.
  return Executable(
      compiler->name().str(), std::move(memory_mapper), std::move(*engine),
      std::move(*signature), std::move(*runtime_signature),
      std::move(*arguments_memory_layout), std::move(*results_memory_layout),
      specialization, time_to_compile);
}
277 
Specialize(ArgumentsRef arguments,ArrayRef<SymbolicShape> symbolic_shapes,ArrayRef<ArgumentConstraint> constraints,const SpecializationListener * listener)278 llvm::Error JitCompiler::Specialize(ArgumentsRef arguments,
279                                     ArrayRef<SymbolicShape> symbolic_shapes,
280                                     ArrayRef<ArgumentConstraint> constraints,
281                                     const SpecializationListener* listener) {
282   assert(!specialized_ && "can specialize executable only once");
283   specialized_ = true;
284 
285   func::FuncOp func = entrypoint();
286 
287   // Update function signature and sink constant arguments into the body.
288   if (auto err = SpecializeFunction(func, arguments, symbolic_shapes,
289                                     constraints, listener)) {
290     // No need to call this->Error() because we don't have diagnostic to report
291     // in case of a failed specialization.
292     return MakeStringError("failed to specialize: ", err);
293   }
294 
295   // Run the user-provided specialization pipeline to take advantage of the
296   // specialized operands and sunk constants.
297   if (failed(RunSpecializationPipeline(*module_, opts_)))
298     return Error("failed to run specialization pipeline");
299 
300   return Error::success();
301 }
302 
303 }  // namespace runtime
304 }  // namespace xla
305