1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "llvm/ADT/Triple.h" 24 #include "llvm/ExecutionEngine/JITEventListener.h" 25 #include "llvm/ExecutionEngine/Orc/Core.h" 26 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" 27 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 28 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 29 #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" 30 #include "llvm/IR/Module.h" 31 #include "llvm/Target/TargetMachine.h" 32 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" 33 #include "tensorflow/compiler/xla/types.h" 34 35 namespace xla { 36 namespace cpu { 37 38 // Simplified LLVM JIT based on the new Orc API. 39 // 40 // This class wraps Orc's functionality into a single interface that only 41 // exposes what we need for XLA. 42 // 43 // Supports JIT-ing multiple modules but without cross-module linking. 44 // Implements eager compilation - the module is lowered to binary as soon as 45 // it's added to the JIT. 46 class SimpleOrcJIT : public llvm::JITEventListener { 47 public: 48 using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; 49 using CompileLayerT = llvm::orc::IRCompileLayer; 50 51 // Create a new JIT, targeting the host architecture. 52 // 53 // {pre,post}_optimization_hook is invoked on the module before/after all 54 // LLVM IR-level optimizations. post_codegen_hook is invoked after 55 // compiling to machine code. 56 SimpleOrcJIT( 57 std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control, 58 std::unique_ptr<llvm::orc::ExecutionSession> execution_session, 59 const llvm::TargetOptions& target_options, 60 llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, 61 bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, 62 LLVMCompiler::ModuleHook pre_optimization_hook, 63 LLVMCompiler::ModuleHook post_optimization_hook, 64 std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook); 65 66 static llvm::Expected<std::unique_ptr<SimpleOrcJIT>> Create( 67 const llvm::TargetOptions& target_options, 68 llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, 69 bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, 70 LLVMCompiler::ModuleHook pre_optimization_hook, 71 LLVMCompiler::ModuleHook post_optimization_hook, 72 std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook); 73 74 ~SimpleOrcJIT() override; 75 data_layout()76 const llvm::DataLayout& data_layout() const { return data_layout_; } 77 target_triple()78 const llvm::Triple& target_triple() const { return target_triple_; } 79 80 llvm::Error AddModule(llvm::orc::ThreadSafeModule module); 81 82 // Discards objects we no longer need once we are done compiling. 83 void DoneCompiling(); 84 85 // Get the runtime address of the compiled symbol whose name is given. Returns 86 // nullptr if the symbol cannot be found. 87 llvm::Expected<llvm::JITEvaluatedSymbol> FindCompiledSymbol( 88 const std::string& name); 89 target_machine()90 llvm::TargetMachine* target_machine() const { return target_machine_.get(); } 91 92 // Creates an llvm::TargetMachine suitable for JITting code that will run on 93 // the current machine. 94 static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT( 95 const llvm::TargetOptions& target_options, 96 llvm::CodeGenOpt::Level opt_level); 97 SizeOfGeneratedCodeInBytes()98 int64_t SizeOfGeneratedCodeInBytes() const { 99 return size_of_generated_code_in_bytes_; 100 } 101 102 private: 103 llvm::JITEvaluatedSymbol ResolveRuntimeSymbol(llvm::StringRef name); 104 105 void notifyObjectLoaded( 106 llvm::JITEventListener::ObjectKey key, 107 const llvm::object::ObjectFile& object, 108 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override; 109 void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override; 110 111 std::unique_ptr<llvm::TargetMachine> target_machine_; 112 llvm::Triple target_triple_; 113 const llvm::DataLayout data_layout_; 114 std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control_; 115 std::unique_ptr<llvm::orc::ExecutionSession> execution_session_; 116 ObjLayerT object_layer_; 117 CompileLayerT compile_layer_; 118 llvm::orc::JITDylib* main_jit_dylib_; 119 int64_t size_of_generated_code_in_bytes_ = 0; 120 121 // Non owning pointer to a JIT event listener that registers the JIT events 122 // with an attached GDB. 123 // 124 // Note: we get a pointer to this event listener using 125 // `createGDBRegistrationListener` which makes it look like we're supposed to 126 // free this, but the function is poorly named and really just returns a 127 // pointer to a static object. 128 llvm::JITEventListener* gdb_jit_event_listener_; 129 }; 130 131 } // namespace cpu 132 } // namespace xla 133 134 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ 135