1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <utility>
23
24 #include "llvm/IR/PassTimingInfo.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project
28 #include "mlir/Parser/Parser.h" // from @llvm-project
29 #include "mlir/Pass/PassManager.h" // from @llvm-project
30 #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
31 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h"
32 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
33 #include "tfrt/support/error_util.h" // from @tf_runtime
34
35 namespace xla {
36 namespace runtime {
37
38 using namespace mlir; // NOLINT
39
40 using llvm::Error;
41 using tfrt::MakeStringError;
42
// Returns true when the JIT compiler was built with debugging enabled
// (controlled by the DEBUG_XLA_RUNTIME_COMPILER build flag).
static bool DebugJitCompiler() {
#if defined(DEBUG_XLA_RUNTIME_COMPILER)
  return true;
#else
  return false;
#endif
}
49
EnablePassTiming()50 static bool EnablePassTiming() {
51 #if defined(ENABLE_XLAR_RUNTIME_PASS_TIMING)
52 return true;
53 #endif
54 return DebugJitCompiler();
55 }
56
57 //===----------------------------------------------------------------------===//
58 // Setup MLIR pass pipeline to lower to LLVM dialect, and use ORC JIT to codegen
59 // functions at runtime.
60 //===----------------------------------------------------------------------===//
61
InitializeCompiler()62 static void InitializeCompiler() {
63 static const bool initialized = ([] {
64 llvm::InitializeNativeTarget();
65 // Initialize asm printer and parser so that we can handle the inline
66 // assembly generated in MLIR for some operations.
67 llvm::InitializeNativeTargetAsmPrinter();
68 llvm::InitializeNativeTargetAsmParser();
69 return true;
70 })();
71 (void)initialized;
72 }
73
SetupPassDebugging(MLIRContext * context,PassManager & pm)74 static void SetupPassDebugging(MLIRContext* context, PassManager& pm) {
75 // Print IR after all passes.
76 if (DebugJitCompiler()) {
77 context->disableMultithreading();
78 pm.enableIRPrinting([](Pass*, Operation*) { return false; },
79 [](Pass*, Operation*) { return true; },
80 /*printModuleScope=*/true,
81 /*printAfterOnlyOnChange=*/false,
82 /*printAfterOnlyOnFailure=*/false, llvm::errs());
83 }
84 }
85
RunPipeline(ModuleOp module,const std::function<void (PassManager &)> & create_pipeline)86 static LogicalResult RunPipeline(
87 ModuleOp module, const std::function<void(PassManager&)>& create_pipeline) {
88 if (!create_pipeline) return success();
89
90 PassManager pm(module.getContext());
91 SetupPassDebugging(module.getContext(), pm);
92
93 // Instrument the pass manager to capture timing information.
94 DefaultTimingManager tm;
95 TimingScope timing;
96 if (EnablePassTiming()) {
97 tm.setEnabled(true);
98 timing = tm.getRootScope();
99 pm.enableTiming(timing);
100 }
101
102 create_pipeline(pm);
103
104 return pm.run(module);
105 }
106
107 // Runs the user-provided compilation pipeline to compile the module to LLVM.
RunCompilationPipeline(ModuleOp module,const JitCompiler::Options & opts)108 static LogicalResult RunCompilationPipeline(ModuleOp module,
109 const JitCompiler::Options& opts) {
110 return RunPipeline(module, opts.create_compilation_pipeline);
111 }
112
113 // Runs the user-provided specialization pipeline.
RunSpecializationPipeline(ModuleOp module,const JitCompiler::Options & opts)114 static LogicalResult RunSpecializationPipeline(
115 ModuleOp module, const JitCompiler::Options& opts) {
116 return RunPipeline(module, opts.create_specialization_pipeline);
117 }
118
119 //===----------------------------------------------------------------------===//
120
121 // Creates a new MLIR Context and registers all the dialects that are expected
122 // in the compiled module.
CreateMlirContext(const JitCompiler::Options & opts)123 static std::unique_ptr<MLIRContext> CreateMlirContext(
124 const JitCompiler::Options& opts) {
125 DialectRegistry registry;
126
127 // Call user-provided callback to register all required dialects.
128 if (opts.register_dialects) opts.register_dialects(registry);
129
130 auto threading = MLIRContext::Threading::DISABLED;
131 auto ctx = std::make_unique<MLIRContext>(registry, threading);
132 ctx->loadAllAvailableDialects();
133 return ctx;
134 }
135
136 //===----------------------------------------------------------------------===//
137 // JitCompiler implementation.
138 //===----------------------------------------------------------------------===//
139
// Constructs a JitCompiler from the serialized MLIR module: sets up the MLIR
// context and diagnostic capture, parses the module, and resolves the
// entrypoint function. On failure `module_` (parse error) or `entrypoint_`
// (unresolved symbol) stays null; `Instantiate` checks both.
//
// NOTE(review): the init list constructs `context_` from `opts_` after the
// move into `opts_` — this is only correct if `opts_` is declared before
// `context_` in the class (member init follows declaration order); confirm
// against the header.
JitCompiler::JitCompiler(JitCompiler::Options opts,
                         std::string_view mlir_module,
                         std::string_view entrypoint)
    : opts_(std::move(opts)),
      context_(CreateMlirContext(opts_)),
      diagnostic_os_(diagnostic_),
      handler_(source_mgr_, context_.get(), diagnostic_os_),
      specialized_(false) {
  // Register the serialized module under the "xla.program" buffer name so
  // diagnostics reference it.
  source_mgr_.AddNewSourceBuffer(
      llvm::MemoryBuffer::getMemBuffer(mlir_module, "xla.program"),
      llvm::SMLoc());

  module_ = parseSourceFile<ModuleOp>(source_mgr_, context_.get());
  // Only look up the entrypoint if parsing succeeded.
  if (module_) entrypoint_ = module_->lookupSymbol<func::FuncOp>(entrypoint);
}
155
156 /*static*/ llvm::Expected<std::unique_ptr<JitCompiler>>
Instantiate(JitCompiler::Options opts,std::string_view mlir_module,std::string_view entrypoint)157 JitCompiler::Instantiate(JitCompiler::Options opts,
158 std::string_view mlir_module,
159 std::string_view entrypoint) {
160 std::unique_ptr<JitCompiler> context(
161 new JitCompiler(std::move(opts), mlir_module, entrypoint));
162 if (!context->module_)
163 return context->Error("failed to parse the mlir source");
164 if (!context->entrypoint_)
165 return context->Error("failed to resolve entrypoint function");
166
167 InitializeCompiler();
168
169 return {std::move(context)};
170 }
171
// Compiles the module held by `compiler` down to an `Executable`:
//   1. Converts the entrypoint signature (and its calling-convention form)
//      to runtime types and computes argument/result memory layouts.
//   2. Runs the user-provided compilation pipeline to lower to LLVM dialect.
//   3. Translates to LLVM IR and JIT-compiles it with ORC.
// Consumes `compiler`. `memory_region_name` tags the mmap'ed code region;
// `specialization` is the optional specialization id recorded in the
// executable (and in the region name).
/*static*/ llvm::Expected<Executable> JitCompiler::Compile(
    std::unique_ptr<JitCompiler> compiler, std::string_view memory_region_name,
    llvm::Optional<size_t> specialization) {
  const JitCompiler::Options& opts = compiler->options();
  func::FuncOp entry_func = compiler->entrypoint();
  // Copy the name out: the FuncOp is invalidated by the pipeline below.
  std::string entrypoint = entry_func.getName().str();

  // We track end-to-end time to compile the final executable.
  auto compilation_start = std::chrono::steady_clock::now();

  // Get the signature of the entrypoint function.
  auto signature = opts.type_converter.Convert(entry_func.getFunctionType());
  if (auto err = signature.takeError()) return std::move(err);

  // Get the calling convention for the entrypoint function.
  if (!opts.calling_convention)
    return compiler->Error("calling convention is not defined");

  // Calling convention conversion can fail if some types are not supported.
  auto runtime_type = opts.calling_convention(entry_func.getFunctionType());
  if (!runtime_type)
    return compiler->Error(
        "calling convention failed to convert entrypoint type");

  // Get the runtime signature of the entrypoint function.
  auto runtime_signature = opts.type_converter.Convert(runtime_type);
  if (auto err = runtime_signature.takeError()) return std::move(err);

  // Get the memory layout for passing function arguments.
  auto arguments_memory_layout =
      Executable::GetArgumentsMemoryLayout(*runtime_signature);
  if (auto err = arguments_memory_layout.takeError()) return std::move(err);

  // Get the memory layout for returning function results.
  auto results_memory_layout =
      Executable::GetResultsMemoryLayout(*runtime_signature);
  if (auto err = results_memory_layout.takeError()) return std::move(err);

  // Mark entry function with an attribute, so it can be converted to an Xla
  // entrypoint (see `rt-convert-to-entrypoint` pass). Must happen before the
  // compilation pipeline runs.
  auto unit_attr = UnitAttr::get(entry_func.getContext());
  entry_func->setAttr(kEntrypointAttrName, unit_attr);

  // Run the compilation pipeline to lower the module to LLVM dialect.
  if (failed(RunCompilationPipeline(compiler->module(), opts)))
    return compiler->Error("failed to run compilation pipeline");

  // Enable LLVM (not MLIR) pass timing; reported and reset after codegen.
  if (EnablePassTiming()) llvm::TimePassesIsEnabled = true;

  // Prepare JIT target machine for code generation.
  auto builder = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!builder) return builder.takeError();

  auto target_machine = builder->createTargetMachine();
  if (!target_machine) return target_machine.takeError();

  // Name of the compiled module if available.
  auto module_name = compiler->module().getSymName().value_or("<unknown>");

  // Memory region name to mmap executable code.
  std::string mapper_name = llvm::formatv(
      "/xla{0}{1}:@{2}::@{3}:{4}", memory_region_name.empty() ? "" : ":",
      EscapeMemRegionName(memory_region_name), module_name, entrypoint,
      specialization.has_value() ? "specialized" : "default");

  // Custom memory mapper to tag memory allocated for XLA executables.
  std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper =
      XlaRuntimeMemoryMapper::Create(std::move(mapper_name));

  // Register symbols required for running XLA Executable.
  ExecutionEngine::SymbolsBinding symbols =
      RuntimeSymbolsBinding(compiler->options().symbols_binding);

  // Construct options for the XLA runtime execution engine.
  ExecutionEngine::JitOptions engine_options;
  engine_options.opt_level = compiler->options().jit_code_opt_level;
  // Raw pointer into `target_machine`; the Expected<unique_ptr> above must
  // stay alive until engine creation completes.
  engine_options.target_machine = target_machine->get();
  engine_options.make_optimizing_transformer = makeOptimizingTransformer;
  engine_options.section_memory_mapper = memory_mapper.get();
  engine_options.symbols_binding = std::move(symbols);

  // Translate MLIR module to the LLVM module.
  auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
  auto llvm_module = translateModuleToLLVMIR(compiler->module(), *llvm_ctx);
  if (!llvm_module)
    return MakeStringError("failed to translate module to LLVM IR");

  // Compile input module to the native function.
  auto engine = ExecutionEngine::CreateFromModule(
      std::move(llvm_ctx), std::move(llvm_module), entrypoint, engine_options);
  if (auto err = engine.takeError()) return std::move(err);

  // At this point compilation is completed, and all symbols in the LLVM module
  // materialized as addresses (entrypoint is an executable function pointer).
  auto time_to_compile = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::steady_clock::now() - compilation_start);

  if (EnablePassTiming()) llvm::reportAndResetTimings();

  return Executable(
      compiler->name().str(), std::move(memory_mapper), std::move(*engine),
      std::move(*signature), std::move(*runtime_signature),
      std::move(*arguments_memory_layout), std::move(*results_memory_layout),
      specialization, time_to_compile);
}
277
// Specializes the module's entrypoint to the given concrete arguments,
// symbolic shapes, and constraints, then runs the user-provided
// specialization pipeline. May be called at most once per compiler instance
// (enforced only by assert in debug builds). `listener`, if non-null, is
// forwarded to `SpecializeFunction` — presumably for observing the applied
// specializations; confirm against its declaration.
llvm::Error JitCompiler::Specialize(ArgumentsRef arguments,
                                    ArrayRef<SymbolicShape> symbolic_shapes,
                                    ArrayRef<ArgumentConstraint> constraints,
                                    const SpecializationListener* listener) {
  assert(!specialized_ && "can specialize executable only once");
  specialized_ = true;

  func::FuncOp func = entrypoint();

  // Update function signature and sink constant arguments into the body.
  if (auto err = SpecializeFunction(func, arguments, symbolic_shapes,
                                    constraints, listener)) {
    // No need to call this->Error() because we don't have diagnostic to report
    // in case of a failed specialization.
    return MakeStringError("failed to specialize: ", err);
  }

  // Run the user-provided specialization pipeline to take advantage of the
  // specialized operands and sunk constants.
  if (failed(RunSpecializationPipeline(*module_, opts_)))
    return Error("failed to run specialization pipeline");

  return Error::success();
}
302
303 } // namespace runtime
304 } // namespace xla
305