1 #ifdef TORCH_ENABLE_LLVM
2
3 #include <c10/macros/Macros.h>
4
5 #include <torch/csrc/jit/tensorexpr/external_functions.h>
6 #include <torch/csrc/jit/tensorexpr/intrinsic_symbols.h>
7 #include <torch/csrc/jit/tensorexpr/llvm_jit.h>
8
9 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
10 #include <llvm/ExecutionEngine/ExecutionEngine.h>
11 #include <llvm/ExecutionEngine/JITSymbol.h>
12 C10_DIAGNOSTIC_POP()
13
14 #include <llvm/ExecutionEngine/Orc/CompileUtils.h>
15 #include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
16 #include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
17 // llvm::SCEVPredicate has virtual function but non-virtual destructor
18 // https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198
19 #pragma GCC diagnostic push
20 #pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
21 #include <llvm/ExecutionEngine/Orc/LLJIT.h>
22 #pragma GCC diagnostic pop
23 #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
24 #include <llvm/ExecutionEngine/Orc/SymbolStringPool.h>
25 #include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
26 #include <llvm/ExecutionEngine/SectionMemoryManager.h>
27 #include <llvm/IR/DataLayout.h>
28 #include <llvm/IR/Mangler.h>
29 #include <llvm/Support/CFGUpdate.h>
30 #include <llvm/Support/DynamicLibrary.h>
31 #if LLVM_VERSION_MAJOR >= 18
32 #include <llvm/TargetParser/Host.h>
33 #else
34 #include <llvm/Support/Host.h>
35 #endif
36 #include <llvm/Support/raw_ostream.h>
37 #include <llvm/Target/TargetMachine.h>
38
39 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
40
41 #include <c10/util/Half.h>
42
43 #include <algorithm>
44 #include <memory>
45 #include <string>
46 #include <unordered_set>
47 #include <vector>
48
49 using namespace torch::jit::tensorexpr;
50
51 template <typename T>
toAddress(T * Ptr)52 static llvm::JITTargetAddress toAddress(T* Ptr) {
53 return static_cast<llvm::JITTargetAddress>(reinterpret_cast<uintptr_t>(Ptr));
54 }
55
56 // Get subtarget features for the host.
getHostSubtargetFeatures()57 static llvm::SubtargetFeatures getHostSubtargetFeatures() {
58 llvm::SubtargetFeatures subtargetFeatures;
59 #if LLVM_VERSION_MAJOR >= 19
60 const auto featureMap = llvm::sys::getHostCPUFeatures();
61 #else
62 llvm::StringMap<bool> featureMap;
63 llvm::sys::getHostCPUFeatures(featureMap);
64 #endif
65 for (auto& feature : featureMap) {
66 subtargetFeatures.AddFeature(feature.first(), feature.second);
67 }
68 return subtargetFeatures;
69 }
70
71 // Create a JTMB using the host's triple. CPU and attrs default to the host
72 // unless they are supplied.
makeJTMBFromHost(std::optional<std::string> cpu,std::optional<std::string> attrs)73 static llvm::orc::JITTargetMachineBuilder makeJTMBFromHost(
74 std::optional<std::string> cpu,
75 std::optional<std::string> attrs) {
76 llvm::orc::JITTargetMachineBuilder JTMB(
77 (llvm::Triple(llvm::sys::getProcessTriple())));
78 JTMB.setCPU(cpu.value_or(llvm::sys::getHostCPUName().str()));
79 if (attrs) {
80 std::vector<std::string> features;
81 llvm::SubtargetFeatures::Split(features, *attrs);
82 JTMB.addFeatures(features);
83 } else {
84 JTMB.addFeatures(getHostSubtargetFeatures().getFeatures());
85 }
86 return JTMB;
87 }
88
89 // Create a JTMB using a given triple. Do not set cpu or attrs if not supplied.
makeJTMBFromTriple(const std::string & triple,std::optional<std::string> cpu,std::optional<std::string> attrs)90 static llvm::orc::JITTargetMachineBuilder makeJTMBFromTriple(
91 const std::string& triple,
92 std::optional<std::string> cpu,
93 std::optional<std::string> attrs) {
94 llvm::orc::JITTargetMachineBuilder JTMB((llvm::Triple(triple)));
95 if (cpu) {
96 JTMB.setCPU(*cpu);
97 }
98 if (attrs) {
99 std::vector<std::string> features;
100 llvm::SubtargetFeatures::Split(features, *attrs);
101 JTMB.addFeatures(features);
102 }
103 return JTMB;
104 }
105
makeTargetMachineBuilder(std::optional<std::string> triple,std::optional<std::string> cpu,std::optional<std::string> attrs)106 static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder(
107 std::optional<std::string> triple,
108 std::optional<std::string> cpu,
109 std::optional<std::string> attrs) {
110 auto JTMB = triple ? makeJTMBFromTriple(*triple, cpu, attrs)
111 : makeJTMBFromHost(cpu, attrs);
112 #if LLVM_VERSION_MAJOR >= 18
113 JTMB.setCodeGenOptLevel(llvm::CodeGenOptLevel::Default);
114 #else
115 JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default);
116 #endif
117 JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;
118 return JTMB;
119 }
120
registerIntrinsics(llvm::orc::JITDylib & JD,llvm::orc::MangleAndInterner & Mangle,std::unordered_set<std::string> & intrinsics)121 static void registerIntrinsics(
122 llvm::orc::JITDylib& JD,
123 llvm::orc::MangleAndInterner& Mangle,
124 std::unordered_set<std::string>& intrinsics) {
125 using namespace llvm;
126 using namespace llvm::orc;
127
128 auto entry = [&](const char* name, auto ptr) -> SymbolMap::value_type {
129 #if LLVM_VERSION_MAJOR >= 17
130 return {Mangle(name), {ExecutorAddr(toAddress(ptr)), JITSymbolFlags::None}};
131 #else
132 return {Mangle(name), {toAddress(ptr), JITSymbolFlags::None}};
133 #endif
134 };
135
136 SymbolMap symbols;
137 for (auto const& sym : getIntrinsicSymbols()) {
138 symbols.insert(entry(sym.symbol, sym.address));
139 intrinsics.insert(sym.symbol);
140 }
141 assertSuccess(JD.define(absoluteSymbols(symbols)));
142
143 for (auto& kv : getNNCFunctionRegistry()) {
144 assertSuccess(
145 JD.define(absoluteSymbols({entry(kv.first.c_str(), kv.second)})));
146 }
147 assertSuccess(JD.define(
148 absoluteSymbols({entry("DispatchParallel", DispatchParallel)})));
149 assertSuccess(
150 JD.define(absoluteSymbols({entry("nnc_aten_free", nnc_aten_free)})));
151 }
152
153 namespace llvm {
154 namespace orc {
155
156 // Lightly modified implementation from LLVM's Kaleidoscope JIT tutorial:
157 // https://llvm.org/docs/tutorial/BuildingAJIT1.html
158 #if LLVM_VERSION_MAJOR >= 9
// JIT implementation for LLVM >= 9, built on ORC's LLJIT.
// Owns a TargetMachine (exposed for codegen queries) and the LLJIT instance
// that compiles and links modules added via addModule().
class TORCH_API PytorchLLVMJITImpl {
 private:
  std::unique_ptr<TargetMachine> TM;
  std::unique_ptr<LLJIT> LLJ;
  // Names registered by registerIntrinsics(); backs hasSymbol().
  std::unordered_set<std::string> intrinsics;

 public:
  // Builds the TargetMachine and LLJIT from the (optional) triple/cpu/attrs,
  // then wires up process-symbol resolution and intrinsic registration.
  PytorchLLVMJITImpl(
      std::optional<std::string> triple,
      std::optional<std::string> cpu,
      std::optional<std::string> attrs)
      : TM(assertSuccess(makeTargetMachineBuilder(triple, cpu, attrs)
                             .createTargetMachine())),
        LLJ(assertSuccess(
            LLJITBuilder()
                .setJITTargetMachineBuilder(
                    makeTargetMachineBuilder(triple, cpu, attrs))
#if LLVM_VERSION_MAJOR >= 17
                // LLVM 17+: use the JITLink-based linking layer with an
                // in-process memory manager instead of RuntimeDyld.
                .setObjectLinkingLayerCreator([&](ExecutionSession& ES,
                                                  const Triple& TT) {
                  return std::make_unique<ObjectLinkingLayer>(
                      ES,
                      assertSuccess(jitlink::InProcessMemoryManager::Create()));
                })
#endif
                .create())) {
    // Let JIT'd code resolve symbols exported by the current process
    // (libc, ATen, etc.).
    auto ProcSymbolsGenerator =
        assertSuccess(DynamicLibrarySearchGenerator::GetForCurrentProcess(
            LLJ->getDataLayout().getGlobalPrefix()));
    auto& JD = LLJ->getMainJITDylib();
    // The generator API was renamed setGenerator -> addGenerator after LLVM 9.
#if LLVM_VERSION_MAJOR == 9
    JD.setGenerator(std::move(ProcSymbolsGenerator));
#else
    JD.addGenerator(std::move(ProcSymbolsGenerator));
#endif

    // Handle platform-specific symbol mangling
    MangleAndInterner Mangle(LLJ->getExecutionSession(), LLJ->getDataLayout());

    // Register implementations of intrinsics
    registerIntrinsics(JD, Mangle, intrinsics);

    // Work around UBSAN crashes which reads 8 byte in front of every function.
    // Placing a dummy variable with 8 bytes first ensures there is readable
    // memory before code for the first function is emitted. See also:
    // - https://reviews.llvm.org/D148665
    // - https://github.com/llvm/llvm-project/issues/65253
    {
      std::unique_ptr<llvm::LLVMContext> ctx =
          std::make_unique<llvm::LLVMContext>();
      std::unique_ptr<llvm::Module> module_ =
          std::make_unique<llvm::Module>("__asan_workaround_fill", *ctx);
      llvm::Type* type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*ctx), 8);
      module_->getOrInsertGlobal("__asan_workaround_fill", type, [&]() {
        return new llvm::GlobalVariable(
            *module_,
            type,
            true,
            llvm::GlobalVariable::InternalLinkage,
            llvm::Constant::getNullValue(type),
            "__asan_workaround_fill");
      });
      assertSuccess(LLJ->addIRModule(
          ThreadSafeModule(std::move(module_), std::move(ctx))));
    }
  }

  // Hand a module (and its context) to LLJIT for compilation; asserts on
  // failure.
  void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C) {
    assertSuccess(
        LLJ->addIRModule(ThreadSafeModule(std::move(M), std::move(C))),
        "Failed to add module to compile layer");
  }

  // Look up a compiled symbol by name; asserts if the lookup fails.
  JITSymbol findSymbol(const std::string Name) {
#if LLVM_VERSION_MAJOR >= 15
    // Starting with llvm-15, LLJIT::lookup returns an address rather than a
    // symbol. Even though an address is what we ultimately we want, we also
    // want to avoid churning our internal APIs, so we wrap the returned address
    // in a fake JITSymbol.
    auto result = assertSuccess(LLJ->lookup(Name));
    return JITSymbol(result.getValue(), JITSymbolFlags());
#else
    return assertSuccess(LLJ->lookup(Name));
#endif
  }

  // True iff `Name` was registered as an intrinsic (no JIT lookup performed).
  bool hasSymbol(const std::string& Name) {
    return intrinsics.find(Name) != intrinsics.end();
  }

  TargetMachine& getTargetMachine() {
    return *TM;
  }

  const DataLayout& getDataLayout() {
    return LLJ->getDataLayout();
  }
};
257
258 #elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009
259
// JIT implementation for LLVM 8 (legacy ORC layers: RTDyldObjectLinkingLayer
// + IRCompileLayer). Member declaration order matters: the DataLayout and the
// layers are initialized from TM, and the Resolver lambda captures `this` to
// reach CompileLayer after construction.
class TORCH_API PytorchLLVMJITImpl {
 private:
  ExecutionSession ES;
  std::shared_ptr<SymbolResolver> Resolver;
  std::unique_ptr<TargetMachine> TM;
  const DataLayout DL;
  RTDyldObjectLinkingLayer ObjectLayer;
  IRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
  // Names registered by registerIntrinsics(); backs hasSymbol().
  std::unordered_set<std::string> intrinsics;

 public:
  PytorchLLVMJITImpl(
      std::optional<std::string> triple,
      std::optional<std::string> cpu,
      std::optional<std::string> attrs)
      : Resolver(createLegacyLookupResolver(
            ES,
            // Symbol resolution order: JIT'd modules first, then the host
            // process, then the main dylib (intrinsics).
            [this](const std::string& Name) -> JITSymbol {
              if (auto Sym = CompileLayer.findSymbol(Name, false)) {
                return Sym;
              } else if (auto Err = Sym.takeError()) {
                return std::move(Err);
              }
              if (auto SymAddr =
                      RTDyldMemoryManager::getSymbolAddressInProcess(Name)) {
                return JITSymbol(SymAddr, JITSymbolFlags::Exported);
              }
              MangleAndInterner Mangle(ES, DL);
              return assertSuccess(
                  lookup({&ES.getMainJITDylib()}, Mangle(Name)));
            },
            [](Error Err) {
              assertSuccess(std::move(Err), "lookupFlags failed");
            })),
        TM(assertSuccess(makeTargetMachineBuilder(triple, cpu, attrs)
                             .createTargetMachine())),
        DL(TM->createDataLayout()),
        ObjectLayer(
            ES,
            [this](VModuleKey) {
              return RTDyldObjectLinkingLayer::Resources{
                  std::make_shared<SectionMemoryManager>(), Resolver};
            }),
        CompileLayer(ObjectLayer, SimpleCompiler(*TM)) {
    auto& JD = ES.getMainJITDylib();
    MangleAndInterner Mangle(ES, DL);
    registerIntrinsics(JD, Mangle, intrinsics);
    // Expose the process' own symbols to the JIT linker.
    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
  }

  TargetMachine& getTargetMachine() {
    return *TM;
  }

  // Hand a module to the compile layer under a fresh VModuleKey; asserts on
  // failure. The LLVMContext parameter is unused in this legacy path.
  void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C) {
    // Add the module to the JIT with a new VModuleKey.
    auto K = ES.allocateVModule();
    assertSuccess(
        CompileLayer.addModule(K, std::move(M)),
        "Failed to add module to compile layer");
  }

  // Mangle `Name` for the target data layout, then look it up (exported
  // symbols only).
  JITSymbol findSymbol(const std::string Name) {
    std::string MangledName;
    raw_string_ostream MangledNameStream(MangledName);
    Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
    return CompileLayer.findSymbol(MangledNameStream.str(), true);
  }

  // True iff `Name` was registered as an intrinsic (no JIT lookup performed).
  bool hasSymbol(const std::string& Name) {
    return intrinsics.find(Name) != intrinsics.end();
  }

  // Resolve a symbol to its executable address; asserts if unresolved.
  JITTargetAddress getSymbolAddress(const std::string Name) {
    return assertSuccess(findSymbol(Name).getAddress());
  }

  void removeModule(VModuleKey K) {
    assertSuccess(CompileLayer.removeModule(K));
  }

  const DataLayout& getDataLayout() {
    return DL;
  }
};
345
346 #else // LLVM_VERSION_MAJOR
347 #error Only LLVM versions 8 and above are supported.
348 #endif
349
PytorchLLVMJIT(std::optional<std::string> triple,std::optional<std::string> cpu,std::optional<std::string> attrs)350 PytorchLLVMJIT::PytorchLLVMJIT(
351 std::optional<std::string> triple,
352 std::optional<std::string> cpu,
353 std::optional<std::string> attrs)
354 : impl_(std::make_unique<PytorchLLVMJITImpl>(triple, cpu, attrs)) {}
355
356 PytorchLLVMJIT::~PytorchLLVMJIT() = default;
357
// Forward a module (and its owning context) to the impl for JIT compilation.
void PytorchLLVMJIT::addModule(
    std::unique_ptr<Module> M,
    std::unique_ptr<LLVMContext> C) {
  impl_->addModule(std::move(M), std::move(C));
}
363
findSymbol(const std::string Name)364 JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) {
365 return impl_->findSymbol(std::move(Name));
366 }
367
// True iff `Name` is one of the intrinsics registered at construction time.
bool PytorchLLVMJIT::hasSymbol(const std::string& Name) {
  return impl_->hasSymbol(Name);
}
371
// Expose the impl's TargetMachine (used by codegen for layout/ISA queries).
TargetMachine& PytorchLLVMJIT::getTargetMachine() {
  return impl_->getTargetMachine();
}
375
// Expose the DataLayout the JIT compiles against.
const DataLayout& PytorchLLVMJIT::getDataLayout() {
  return impl_->getDataLayout();
}
379
380 #if !defined(NDEBUG)
// Force a reference to llvm::cfg::Update::dump from a debug build.
void dumpCFG(const llvm::cfg::Update<llvm::BasicBlock*>& update) {
  // XXX: This method call is only here to placate gcov builds. The `dump`
  // method is conditionally defined when NDEBUG is unset, so if you try to
  // link a debug-mode pytorch with an opt-mode llvm, the symbol is undefined.
  update.dump();
}
387 #endif
388
389 } // end namespace orc
390 } // end namespace llvm
391
392 #endif // TORCH_ENABLE_LLVM
393