// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/llvm_jit.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef TORCH_ENABLE_LLVM
2 
3 #include <c10/macros/Macros.h>
4 
5 #include <torch/csrc/jit/tensorexpr/external_functions.h>
6 #include <torch/csrc/jit/tensorexpr/intrinsic_symbols.h>
7 #include <torch/csrc/jit/tensorexpr/llvm_jit.h>
8 
9 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
10 #include <llvm/ExecutionEngine/ExecutionEngine.h>
11 #include <llvm/ExecutionEngine/JITSymbol.h>
12 C10_DIAGNOSTIC_POP()
13 
14 #include <llvm/ExecutionEngine/Orc/CompileUtils.h>
15 #include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
16 #include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
17 // llvm::SCEVPredicate has virtual function but non-virtual destructor
18 // https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198
19 #pragma GCC diagnostic push
20 #pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
21 #include <llvm/ExecutionEngine/Orc/LLJIT.h>
22 #pragma GCC diagnostic pop
23 #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
24 #include <llvm/ExecutionEngine/Orc/SymbolStringPool.h>
25 #include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
26 #include <llvm/ExecutionEngine/SectionMemoryManager.h>
27 #include <llvm/IR/DataLayout.h>
28 #include <llvm/IR/Mangler.h>
29 #include <llvm/Support/CFGUpdate.h>
30 #include <llvm/Support/DynamicLibrary.h>
31 #if LLVM_VERSION_MAJOR >= 18
32 #include <llvm/TargetParser/Host.h>
33 #else
34 #include <llvm/Support/Host.h>
35 #endif
36 #include <llvm/Support/raw_ostream.h>
37 #include <llvm/Target/TargetMachine.h>
38 
39 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
40 
41 #include <c10/util/Half.h>
42 
43 #include <algorithm>
44 #include <memory>
45 #include <string>
46 #include <unordered_set>
47 #include <vector>
48 
49 using namespace torch::jit::tensorexpr;
50 
51 template <typename T>
toAddress(T * Ptr)52 static llvm::JITTargetAddress toAddress(T* Ptr) {
53   return static_cast<llvm::JITTargetAddress>(reinterpret_cast<uintptr_t>(Ptr));
54 }
55 
56 // Get subtarget features for the host.
getHostSubtargetFeatures()57 static llvm::SubtargetFeatures getHostSubtargetFeatures() {
58   llvm::SubtargetFeatures subtargetFeatures;
59 #if LLVM_VERSION_MAJOR >= 19
60   const auto featureMap = llvm::sys::getHostCPUFeatures();
61 #else
62   llvm::StringMap<bool> featureMap;
63   llvm::sys::getHostCPUFeatures(featureMap);
64 #endif
65   for (auto& feature : featureMap) {
66     subtargetFeatures.AddFeature(feature.first(), feature.second);
67   }
68   return subtargetFeatures;
69 }
70 
71 // Create a JTMB using the host's triple.  CPU and attrs default to the host
72 // unless they are supplied.
makeJTMBFromHost(std::optional<std::string> cpu,std::optional<std::string> attrs)73 static llvm::orc::JITTargetMachineBuilder makeJTMBFromHost(
74     std::optional<std::string> cpu,
75     std::optional<std::string> attrs) {
76   llvm::orc::JITTargetMachineBuilder JTMB(
77       (llvm::Triple(llvm::sys::getProcessTriple())));
78   JTMB.setCPU(cpu.value_or(llvm::sys::getHostCPUName().str()));
79   if (attrs) {
80     std::vector<std::string> features;
81     llvm::SubtargetFeatures::Split(features, *attrs);
82     JTMB.addFeatures(features);
83   } else {
84     JTMB.addFeatures(getHostSubtargetFeatures().getFeatures());
85   }
86   return JTMB;
87 }
88 
89 // Create a JTMB using a given triple.  Do not set cpu or attrs if not supplied.
makeJTMBFromTriple(const std::string & triple,std::optional<std::string> cpu,std::optional<std::string> attrs)90 static llvm::orc::JITTargetMachineBuilder makeJTMBFromTriple(
91     const std::string& triple,
92     std::optional<std::string> cpu,
93     std::optional<std::string> attrs) {
94   llvm::orc::JITTargetMachineBuilder JTMB((llvm::Triple(triple)));
95   if (cpu) {
96     JTMB.setCPU(*cpu);
97   }
98   if (attrs) {
99     std::vector<std::string> features;
100     llvm::SubtargetFeatures::Split(features, *attrs);
101     JTMB.addFeatures(features);
102   }
103   return JTMB;
104 }
105 
makeTargetMachineBuilder(std::optional<std::string> triple,std::optional<std::string> cpu,std::optional<std::string> attrs)106 static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder(
107     std::optional<std::string> triple,
108     std::optional<std::string> cpu,
109     std::optional<std::string> attrs) {
110   auto JTMB = triple ? makeJTMBFromTriple(*triple, cpu, attrs)
111                      : makeJTMBFromHost(cpu, attrs);
112 #if LLVM_VERSION_MAJOR >= 18
113   JTMB.setCodeGenOptLevel(llvm::CodeGenOptLevel::Default);
114 #else
115   JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default);
116 #endif
117   JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;
118   return JTMB;
119 }
120 
registerIntrinsics(llvm::orc::JITDylib & JD,llvm::orc::MangleAndInterner & Mangle,std::unordered_set<std::string> & intrinsics)121 static void registerIntrinsics(
122     llvm::orc::JITDylib& JD,
123     llvm::orc::MangleAndInterner& Mangle,
124     std::unordered_set<std::string>& intrinsics) {
125   using namespace llvm;
126   using namespace llvm::orc;
127 
128   auto entry = [&](const char* name, auto ptr) -> SymbolMap::value_type {
129 #if LLVM_VERSION_MAJOR >= 17
130     return {Mangle(name), {ExecutorAddr(toAddress(ptr)), JITSymbolFlags::None}};
131 #else
132     return {Mangle(name), {toAddress(ptr), JITSymbolFlags::None}};
133 #endif
134   };
135 
136   SymbolMap symbols;
137   for (auto const& sym : getIntrinsicSymbols()) {
138     symbols.insert(entry(sym.symbol, sym.address));
139     intrinsics.insert(sym.symbol);
140   }
141   assertSuccess(JD.define(absoluteSymbols(symbols)));
142 
143   for (auto& kv : getNNCFunctionRegistry()) {
144     assertSuccess(
145         JD.define(absoluteSymbols({entry(kv.first.c_str(), kv.second)})));
146   }
147   assertSuccess(JD.define(
148       absoluteSymbols({entry("DispatchParallel", DispatchParallel)})));
149   assertSuccess(
150       JD.define(absoluteSymbols({entry("nnc_aten_free", nnc_aten_free)})));
151 }
152 
153 namespace llvm {
154 namespace orc {
155 
156 // Lightly modified implementation from LLVM's Kaleidoscope JIT tutorial:
157 // https://llvm.org/docs/tutorial/BuildingAJIT1.html
158 #if LLVM_VERSION_MAJOR >= 9
159 class TORCH_API PytorchLLVMJITImpl {
160  private:
161   std::unique_ptr<TargetMachine> TM;
162   std::unique_ptr<LLJIT> LLJ;
163   std::unordered_set<std::string> intrinsics;
164 
165  public:
PytorchLLVMJITImpl(std::optional<std::string> triple,std::optional<std::string> cpu,std::optional<std::string> attrs)166   PytorchLLVMJITImpl(
167       std::optional<std::string> triple,
168       std::optional<std::string> cpu,
169       std::optional<std::string> attrs)
170       : TM(assertSuccess(makeTargetMachineBuilder(triple, cpu, attrs)
171                              .createTargetMachine())),
172         LLJ(assertSuccess(
173             LLJITBuilder()
174                 .setJITTargetMachineBuilder(
175                     makeTargetMachineBuilder(triple, cpu, attrs))
176 #if LLVM_VERSION_MAJOR >= 17
177                 .setObjectLinkingLayerCreator([&](ExecutionSession& ES,
178                                                   const Triple& TT) {
179                   return std::make_unique<ObjectLinkingLayer>(
180                       ES,
181                       assertSuccess(jitlink::InProcessMemoryManager::Create()));
182                 })
183 #endif
184                 .create())) {
185     auto ProcSymbolsGenerator =
186         assertSuccess(DynamicLibrarySearchGenerator::GetForCurrentProcess(
187             LLJ->getDataLayout().getGlobalPrefix()));
188     auto& JD = LLJ->getMainJITDylib();
189 #if LLVM_VERSION_MAJOR == 9
190     JD.setGenerator(std::move(ProcSymbolsGenerator));
191 #else
192     JD.addGenerator(std::move(ProcSymbolsGenerator));
193 #endif
194 
195     // Handle platform-specific symbol mangling
196     MangleAndInterner Mangle(LLJ->getExecutionSession(), LLJ->getDataLayout());
197 
198     // Register implementations of intrinsics
199     registerIntrinsics(JD, Mangle, intrinsics);
200 
201     // Work around UBSAN crashes which reads 8 byte in front of every function.
202     // Placing a dummy variable with 8 bytes first ensures there is readable
203     // memory before code for the first function is emitted. See also:
204     // - https://reviews.llvm.org/D148665
205     // - https://github.com/llvm/llvm-project/issues/65253
206     {
207       std::unique_ptr<llvm::LLVMContext> ctx =
208           std::make_unique<llvm::LLVMContext>();
209       std::unique_ptr<llvm::Module> module_ =
210           std::make_unique<llvm::Module>("__asan_workaround_fill", *ctx);
211       llvm::Type* type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*ctx), 8);
__anon36bdda790302() 212       module_->getOrInsertGlobal("__asan_workaround_fill", type, [&]() {
213         return new llvm::GlobalVariable(
214             *module_,
215             type,
216             true,
217             llvm::GlobalVariable::InternalLinkage,
218             llvm::Constant::getNullValue(type),
219             "__asan_workaround_fill");
220       });
221       assertSuccess(LLJ->addIRModule(
222           ThreadSafeModule(std::move(module_), std::move(ctx))));
223     }
224   }
225 
addModule(std::unique_ptr<Module> M,std::unique_ptr<LLVMContext> C)226   void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C) {
227     assertSuccess(
228         LLJ->addIRModule(ThreadSafeModule(std::move(M), std::move(C))),
229         "Failed to add module to compile layer");
230   }
231 
findSymbol(const std::string Name)232   JITSymbol findSymbol(const std::string Name) {
233 #if LLVM_VERSION_MAJOR >= 15
234     // Starting with llvm-15, LLJIT::lookup returns an address rather than a
235     // symbol. Even though an address is what we ultimately we want, we also
236     // want to avoid churning our internal APIs, so we wrap the returned address
237     // in a fake JITSymbol.
238     auto result = assertSuccess(LLJ->lookup(Name));
239     return JITSymbol(result.getValue(), JITSymbolFlags());
240 #else
241     return assertSuccess(LLJ->lookup(Name));
242 #endif
243   }
244 
hasSymbol(const std::string & Name)245   bool hasSymbol(const std::string& Name) {
246     return intrinsics.find(Name) != intrinsics.end();
247   }
248 
getTargetMachine()249   TargetMachine& getTargetMachine() {
250     return *TM;
251   }
252 
getDataLayout()253   const DataLayout& getDataLayout() {
254     return LLJ->getDataLayout();
255   }
256 };
257 
258 #elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009
259 
260 class TORCH_API PytorchLLVMJITImpl {
261  private:
262   ExecutionSession ES;
263   std::shared_ptr<SymbolResolver> Resolver;
264   std::unique_ptr<TargetMachine> TM;
265   const DataLayout DL;
266   RTDyldObjectLinkingLayer ObjectLayer;
267   IRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
268   std::unordered_set<std::string> intrinsics;
269 
270  public:
271   PytorchLLVMJITImpl(
272       std::optional<std::string> triple,
273       std::optional<std::string> cpu,
274       std::optional<std::string> attrs)
275       : Resolver(createLegacyLookupResolver(
276             ES,
277             [this](const std::string& Name) -> JITSymbol {
278               if (auto Sym = CompileLayer.findSymbol(Name, false)) {
279                 return Sym;
280               } else if (auto Err = Sym.takeError()) {
281                 return std::move(Err);
282               }
283               if (auto SymAddr =
284                       RTDyldMemoryManager::getSymbolAddressInProcess(Name)) {
285                 return JITSymbol(SymAddr, JITSymbolFlags::Exported);
286               }
287               MangleAndInterner Mangle(ES, DL);
288               return assertSuccess(
289                   lookup({&ES.getMainJITDylib()}, Mangle(Name)));
290             },
291             [](Error Err) {
292               assertSuccess(std::move(Err), "lookupFlags failed");
293             })),
294         TM(assertSuccess(makeTargetMachineBuilder(triple, cpu, attrs)
295                              .createTargetMachine())),
296         DL(TM->createDataLayout()),
297         ObjectLayer(
298             ES,
299             [this](VModuleKey) {
300               return RTDyldObjectLinkingLayer::Resources{
301                   std::make_shared<SectionMemoryManager>(), Resolver};
302             }),
303         CompileLayer(ObjectLayer, SimpleCompiler(*TM)) {
304     auto& JD = ES.getMainJITDylib();
305     MangleAndInterner Mangle(ES, DL);
306     registerIntrinsics(JD, Mangle, intrinsics);
307     llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
308   }
309 
310   TargetMachine& getTargetMachine() {
311     return *TM;
312   }
313 
314   void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C) {
315     // Add the module to the JIT with a new VModuleKey.
316     auto K = ES.allocateVModule();
317     assertSuccess(
318         CompileLayer.addModule(K, std::move(M)),
319         "Failed to add module to compile layer");
320   }
321 
322   JITSymbol findSymbol(const std::string Name) {
323     std::string MangledName;
324     raw_string_ostream MangledNameStream(MangledName);
325     Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
326     return CompileLayer.findSymbol(MangledNameStream.str(), true);
327   }
328 
329   bool hasSymbol(const std::string& Name) {
330     return intrinsics.find(Name) != intrinsics.end();
331   }
332 
333   JITTargetAddress getSymbolAddress(const std::string Name) {
334     return assertSuccess(findSymbol(Name).getAddress());
335   }
336 
337   void removeModule(VModuleKey K) {
338     assertSuccess(CompileLayer.removeModule(K));
339   }
340 
341   const DataLayout& getDataLayout() {
342     return DL;
343   }
344 };
345 
346 #else // LLVM_VERSION_MAJOR
347 #error Only LLVM versions 8 and above are supported.
348 #endif
349 
PytorchLLVMJIT(std::optional<std::string> triple,std::optional<std::string> cpu,std::optional<std::string> attrs)350 PytorchLLVMJIT::PytorchLLVMJIT(
351     std::optional<std::string> triple,
352     std::optional<std::string> cpu,
353     std::optional<std::string> attrs)
354     : impl_(std::make_unique<PytorchLLVMJITImpl>(triple, cpu, attrs)) {}
355 
356 PytorchLLVMJIT::~PytorchLLVMJIT() = default;
357 
addModule(std::unique_ptr<Module> M,std::unique_ptr<LLVMContext> C)358 void PytorchLLVMJIT::addModule(
359     std::unique_ptr<Module> M,
360     std::unique_ptr<LLVMContext> C) {
361   impl_->addModule(std::move(M), std::move(C));
362 }
363 
findSymbol(const std::string Name)364 JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) {
365   return impl_->findSymbol(std::move(Name));
366 }
367 
hasSymbol(const std::string & Name)368 bool PytorchLLVMJIT::hasSymbol(const std::string& Name) {
369   return impl_->hasSymbol(Name);
370 }
371 
getTargetMachine()372 TargetMachine& PytorchLLVMJIT::getTargetMachine() {
373   return impl_->getTargetMachine();
374 }
375 
getDataLayout()376 const DataLayout& PytorchLLVMJIT::getDataLayout() {
377   return impl_->getDataLayout();
378 }
379 
380 #if !defined(NDEBUG)
dumpCFG(const llvm::cfg::Update<llvm::BasicBlock * > & update)381 void dumpCFG(const llvm::cfg::Update<llvm::BasicBlock*>& update) {
382   // XXX: This method call is only here to placate gcov builds.  The `dump`
383   // method is conditionally defined when NDEBUG is unset, so if you try to
384   // link a debug-mode pytorch with an opt-mode llvm, the symbol is undefined.
385   update.dump();
386 }
387 #endif
388 
389 } // end namespace orc
390 } // end namespace llvm
391 
392 #endif // TORCH_ENABLE_LLVM
393