#ifdef TORCH_ENABLE_LLVM

#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>

#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/llvm_jit.h>

// Note [llvm::SCEVPredicate non-virtual destructor]
// llvm::SCEVPredicate has virtual functions but a non-virtual destructor
// https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
#include <llvm/Analysis/TargetTransformInfo.h>
#pragma GCC diagnostic pop

#include <llvm/Analysis/CGSCCPassManager.h>
#include <llvm/Analysis/LoopAnalysisManager.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/MC/MCSubtargetInfo.h>
#include <llvm/Pass.h>

// see Note [llvm::SCEVPredicate non-virtual destructor]
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
#include <llvm/Passes/PassBuilder.h>
#pragma GCC diagnostic pop

#if LLVM_VERSION_MAJOR >= 18
#include <llvm/TargetParser/Host.h>
#else
#include <llvm/Support/Host.h>
#endif
#include <llvm/Support/TargetSelect.h>
#include <llvm/Transforms/IPO/AlwaysInliner.h>
#include <llvm/Transforms/Scalar/DCE.h>
#include <llvm/Transforms/Vectorize/LoopVectorize.h>
#include <llvm/Transforms/Vectorize/SLPVectorizer.h>

#if LLVM_VERSION_MAJOR >= 10
#include <llvm/Support/CodeGen.h>
#else
#include <llvm/Target/TargetMachine.h>
#endif

#if LLVM_VERSION_MAJOR >= 11
#include <llvm/Support/TypeSize.h>
#endif

#if LLVM_VERSION_MAJOR < 15
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#endif

#include <llvm/Transforms/Scalar.h>

#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <torch/csrc/jit/jit_log.h>

#include <memory>

using namespace torch::jit::tensorexpr;

C10_DEFINE_bool(
    torch_jit_llvm_use_fast_intrinsics,
    false,
    "Use fast (but slightly less accurate) implementations of tanh and sigmoid");

namespace torch::jit::tensorexpr {

std::optional<std::string>& LLVMTargetTriple() {
  static std::optional<std::string> triple = std::nullopt;
  return triple;
}
std::optional<std::string>& LLVMTargetCPU() {
  static std::optional<std::string> cpu = std::nullopt;
  return cpu;
}
std::optional<std::string>& LLVMTargetAttrs() {
  static std::optional<std::string> attrs = std::nullopt;
  return attrs;
}
bool& LLVMAOTWorkflow() {
  static bool aot_workflow = false;
  return aot_workflow;
}

namespace {

#if LLVM_VERSION_MAJOR >= 15
// Address and type pair to assist in handling of opaque pointers.
struct TypedPointer {
  TypedPointer() = default;
  TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {}
  llvm::Type* type = nullptr;
  llvm::Value* addr = nullptr;
};
#endif

llvm::CmpInst::Predicate llvm_comparison_predicate(
    CompareSelectOperation compare_op,
    const ScalarType& type) {
  switch (compare_op) {
    case CompareSelectOperation::kEQ:
      return llvm::ICmpInst::ICMP_EQ;
    case CompareSelectOperation::kNE:
      return llvm::ICmpInst::ICMP_NE;
    case CompareSelectOperation::kGT:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SGT
                                     : llvm::ICmpInst::ICMP_UGT;
    case CompareSelectOperation::kGE:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SGE
                                     : llvm::ICmpInst::ICMP_UGE;
    case CompareSelectOperation::kLT:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SLT
                                     : llvm::ICmpInst::ICMP_ULT;
    case CompareSelectOperation::kLE:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SLE
                                     : llvm::ICmpInst::ICMP_ULE;
    default:
      // TODO: change to a proper error report
      throw std::runtime_error("invalid operator type");
  }
}

llvm::CmpInst::Predicate llvm_fp_comparison_predicate(
    CompareSelectOperation compare_op) {
  switch (compare_op) {
    case CompareSelectOperation::kEQ:
      return llvm::FCmpInst::FCMP_OEQ;
    case CompareSelectOperation::kNE:
      return llvm::FCmpInst::FCMP_ONE;
    case CompareSelectOperation::kGT:
      return llvm::FCmpInst::FCMP_OGT;
    case CompareSelectOperation::kGE:
      return llvm::FCmpInst::FCMP_OGE;
    case CompareSelectOperation::kLT:
      return llvm::FCmpInst::FCMP_OLT;
    case CompareSelectOperation::kLE:
      return llvm::FCmpInst::FCMP_OLE;
    default:
      // TODO: change to a proper error report
      throw std::runtime_error("invalid operator type");
  }
}

#if LLVM_VERSION_MAJOR <= 9
int ElementCount(int lanes) {
  return lanes;
}
#else
llvm::ElementCount ElementCount(int lanes) {
#if LLVM_VERSION_MAJOR <= 11
  return llvm::ElementCount(static_cast<unsigned>(lanes), false);
#elif LLVM_VERSION_MAJOR >= 12
  return llvm::ElementCount::getFixed(lanes);
#else
#error Only LLVM versions 8 and above are supported.
#endif
}
#endif

#if LLVM_VERSION_MAJOR >= 9

using FunctionCallee = llvm::FunctionCallee;

#elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009

struct FunctionCallee {
  FunctionCallee() {}

  FunctionCallee(llvm::Constant* fn)
      : v_(fn), ft_(cast<llvm::Function>(v_)->getFunctionType()) {}

  llvm::FunctionType* getFunctionType() {
    return ft_;
  }

  llvm::Value* getCallee() {
    return v_;
  }

 private:
  llvm::Value* v_{nullptr};
  llvm::FunctionType* ft_{nullptr};
};

#else
#error Only LLVM versions 8 and above are supported.
#endif
} // namespace

class LLVMCodeGenCallee {
 public:
  LLVMCodeGenCallee(
      std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit,
      void* kernelAddress)
      : jit_(std::move(jit)), kernelAddress_(kernelAddress) {}

  llvm::orc::PytorchLLVMJIT* getJIT() {
    return jit_.get();
  }

  void* getKernelAddress() {
    return kernelAddress_;
  }

  void setKernelAddress(void* kernelAddress) {
    kernelAddress_ = kernelAddress;
  }

 private:
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
  void* kernelAddress_;
};

class LLVMCodeGenImpl : public IRVisitor {
 private:
  std::unique_ptr<llvm::LLVMContext> context_;
  llvm::IRBuilder<> irb_;
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
  std::unique_ptr<llvm::Module> module_;
  llvm::Function* fn_;
  llvm::BasicBlock* bb_;
  llvm::Value* value_{nullptr};
  llvm::JITTargetAddress kernelAddress_;
  std::string kernel_func_name_;

#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_;
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE);
#undef LLVM_TYPE_DECLARE

#if LLVM_VERSION_MAJOR >= 15
  llvm::Type* OpqPtrTy_;
#else
  llvm::Type* Int8PtrTy_;
#endif
  llvm::Type* VoidTy_;
  std::unordered_map<VarPtr, int> varToArg_;
  std::unordered_map<VarPtr, llvm::Value*> varToVal_;
  std::unordered_set<BufPtr> bufsExtAlloc_;
  std::unordered_map<VarPtr, llvm::Value*> bufsExtToFreeVal_;
  std::unordered_multimap<BufPtr, BufPtr> bufsExtAllocReuse_;
  std::unordered_map<BlockPtr, std::vector<VarPtr>> scopeToVar_;
  BlockPtr scope_;

  std::string llvmCode_;
  std::string asmCode_;

 private:
  llvm::LLVMContext& getContext();
  llvm::Type* dtypeToLLVM(Dtype dtype);
  llvm::Type* dtypeToLLVMPtr(Dtype dtype);
  void emitWrapper(const std::vector<llvm::Type*>& params);
  void emitKernel(StmtPtr stmt, const std::vector<llvm::Type*>& params);
  llvm::Value* toVec(llvm::Value* v, int lanes);

  enum Arity {
    Unary,
    Binary,
  };

  using SimdCallee = std::tuple<llvm::FunctionType*, llvm::Value*, bool>;
  SimdCallee getSimdFunction(
      const std::string& name,
      llvm::Type* type,
      Arity arity,
      int lanes);

  llvm::Value* varToValue(VarPtr var);
  void replaceVarMapping(
      const std::vector<VarPtr>& vars,
      const std::vector<llvm::Value*>& vals);

#if LLVM_VERSION_MAJOR >= 15
  TypedPointer packFuncArgs(const std::vector<llvm::Value*>& func_args);
  std::vector<llvm::Value*> unpackFuncArgs(TypedPointer packed, int arg_count);
#else
  llvm::Value* packFuncArgs(const std::vector<llvm::Value*>& func_args);
  std::vector<llvm::Value*> unpackFuncArgs(llvm::Value* packed, int arg_count);
#endif

  void processParallelFor(ForPtr v);
  void handleBufReuse(BufPtr buf, BufPtr buf_to_reuse);

 public:
  LLVMCodeGenImpl(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>& args,
      at::Device device,
      Dtype dtype,
      std::string kernel_func_name,
      std::optional<std::string> triple,
      std::optional<std::string> cpu,
      std::optional<std::string> attrs);
  ~LLVMCodeGenImpl() override = default;

  llvm::JITTargetAddress getKernelAddress() const;
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> releaseJIT();

  void visit(const AddPtr& v) override;
  void visit(const SubPtr& v) override;
  void visit(const MulPtr& v) override;
  void visit(const DivPtr& v) override;
  void visit(const ModPtr& v) override;
  void visit(const MaxPtr& v) override;
  void visit(const MinPtr& v) override;
  void visit(const AndPtr& v) override;
  void visit(const OrPtr& v) override;
  void visit(const XorPtr& v) override;
  void visit(const LshiftPtr& v) override;
  void visit(const RshiftPtr& v) override;
  void visit(const CompareSelectPtr& v) override;

#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##ImmPtr& v) override;
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE

  void visit(const CastPtr& v) override;
  void visit(const BitCastPtr& v) override;
  void visit(const VarPtr& v) override;
  void visit(const RampPtr& v) override;
  void visit(const LoadPtr& v) override;
  void visit(const ForPtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const BroadcastPtr& v) override;
  void visit(const IfThenElsePtr& v) override;
  void visit(const IntrinsicsPtr& v) override;
  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const FreeExtPtr& v) override;
  void visit(const PlacementAllocatePtr& v) override;
  void visit(const LetPtr& v) override;
  void visit(const CondPtr& v) override;
  void visit(const ExternalCallPtr& v) override;
  void visit(const ExternalCallWithAllocPtr& v) override;

  void emitIsNan(IntrinsicsPtr v);

  llvm::Value* emitUnmaskedLoad(
      llvm::Type* ty,
      llvm::Value* addr,
      llvm::Value* idx);
  llvm::Value* emitMaskedLoad(
      llvm::Type* ty,
      llvm::Value* addr,
      llvm::Value* idx,
      llvm::Value* mask);
  void emitUnmaskedStore(
      llvm::Type* ty,
      llvm::Value* base,
      llvm::Value* idx,
      llvm::Value* val);
  void emitMaskedStore(
      llvm::Type* ty,
      llvm::Value* base,
      llvm::Value* idx,
      llvm::Value* mask,
      llvm::Value* val);

  void optimize(llvm::Module& M);
  std::string getLLVMCodeText() {
    return llvmCode_;
  }
  std::string getASMCodeText() {
    return asmCode_;
  }
};

} // namespace torch::jit::tensorexpr

LLVMCodeGen::~LLVMCodeGen() = default;

LLVMCodeGen::LLVMCodeGen(StmtPtr stmt)
    : LLVMCodeGen(stmt, std::vector<CodeGen::BufferArg>()) {}

LLVMCodeGen::LLVMCodeGen(
    StmtPtr stmt,
    const std::vector<BufferArg>& args,
    at::Device device,
    const std::string& kernel_func_name,
    Dtype dtype,
    std::optional<std::string> triple,
    std::optional<std::string> cpu,
    std::optional<std::string> attrs)
    : CodeGen(stmt, args, device, kernel_func_name) {
  impl_ = std::make_unique<LLVMCodeGenImpl>(
      this->stmt(),
      args,
      device,
      dtype,
      this->kernel_func_name(),
      triple,
      cpu,
      attrs);
  callee_ = std::make_unique<LLVMCodeGenCallee>(
      impl_->releaseJIT(), (void*)impl_->getKernelAddress());
}

void LLVMCodeGen::cleanup_memory() {
  impl_.reset(nullptr);
}

void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
  value<float>(const_cast<void**>(args.data()));
}

void LLVMCodeGen::call_with_numel(void** args, int64_t /* numel */) {
  value<float>(const_cast<void**>(args));
}

void LLVMCodeGen::call(const std::vector<CallArg>& args) {
  auto& buf_args = buffer_args();
  if (args.size() != buf_args.size()) {
    throw malformed_input("wrong number of args in call");
  }

  constexpr unsigned nargs = 8;
  c10::SmallVector<void*, nargs> argv;
  argv.resize(buf_args.size());
  for (size_t i = 0, e = buf_args.size(); i < e; i++) {
    auto const& bufferArg = buf_args[i];
    auto const& callArg = args[i];
    argv[i] = argToPtr(bufferArg, callArg);
  }
  value<float>(argv.data());
}

at::Tensor LLVMCodeGen::empty_strided(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  return at::native::empty_strided_cpu(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

void* LLVMCodeGen::getKernelAddress(LLVMCodeGenCallee* callee) {
  return (void*)callee->getKernelAddress();
}

std::string LLVMCodeGen::getCodeText(const std::string& attr /*=""*/) {
  TORCH_INTERNAL_ASSERT(
      impl_.get(),
      "LLVMCodeGen memory has been cleaned up. So, code text is not available at this point");
  if (attr == "asm") {
    return impl_->getASMCodeText();
  } else {
    return impl_->getLLVMCodeText();
  }
}

llvm::JITTargetAddress LLVMCodeGenImpl::getKernelAddress() const {
  return kernelAddress_;
}

std::unique_ptr<llvm::orc::PytorchLLVMJIT> LLVMCodeGenImpl::releaseJIT() {
  return std::move(jit_);
}

namespace {
// Global mutex to protect LLVM initialization.  TargetRegistry::lookupTarget
// in particular is not thread-safe.
static std::mutex llvmInitMutex;
} // namespace

LLVMCodeGenImpl::LLVMCodeGenImpl(
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& args,
    at::Device device,
    Dtype dtype,
    std::string kernel_func_name,
    std::optional<std::string> triple,
    std::optional<std::string> cpu,
    std::optional<std::string> attrs)
    : context_(std::make_unique<llvm::LLVMContext>()),
      irb_(getContext()),
      kernel_func_name_(std::move(kernel_func_name)),
      bufsExtAlloc_(ExternalAllocBufFinder::find(stmt)) {
  if (!triple) {
    triple = LLVMTargetTriple();
  }
  if (!cpu) {
    cpu = LLVMTargetCPU();
  }
  if (!attrs) {
    attrs = LLVMTargetAttrs();
  }
  // Manually map types to LLVM types.
  ByteTy_ = llvm::Type::getInt8Ty(getContext());
  CharTy_ = llvm::Type::getInt8Ty(getContext());
  ShortTy_ = llvm::Type::getInt16Ty(getContext());
  IntTy_ = llvm::Type::getInt32Ty(getContext());
  LongTy_ = llvm::Type::getInt64Ty(getContext());
  HalfTy_ = llvm::Type::getHalfTy(getContext());
  FloatTy_ = llvm::Type::getFloatTy(getContext());
  DoubleTy_ = llvm::Type::getDoubleTy(getContext());
  VoidTy_ = llvm::Type::getVoidTy(getContext());
  BoolTy_ = ByteTy_;
#if LLVM_VERSION_MAJOR >= 15
  OpqPtrTy_ = llvm::PointerType::getUnqual(getContext());
#else
  Int8PtrTy_ = llvm::Type::getInt8PtrTy(getContext());
#endif

  {
    std::lock_guard<std::mutex> g(llvmInitMutex);
    llvm::InitializeAllTargets();
    llvm::InitializeAllTargetMCs();
    llvm::InitializeAllAsmPrinters();
    jit_ = std::make_unique<llvm::orc::PytorchLLVMJIT>(triple, cpu, attrs);
  }

  module_ = std::make_unique<llvm::Module>("pytorch", getContext());
  module_->setDataLayout(jit_->getDataLayout());
  module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str());

  // We support float16 ops by casting expr inputs to float32
  // and then casting the result back to float16
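  // e.g. a Half add conceptually becomes: Half(Float(x) + Float(y)).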

  GRAPH_DEBUG("Before HalfRewriter ", *stmt);
  HalfRewriter hsFix;
  stmt = stmt->accept_mutator(&hsFix);
  GRAPH_DEBUG("After HalfRewriter ", *stmt);

  // Emit prototype and bind argument Vars to parameter indices.
  llvm::Type* retTy = dtypeToLLVM(dtype);
  std::vector<llvm::Type*> params;
  for (const auto i : c10::irange(args.size())) {
    auto const& arg = args[i];
    if (arg.isVar()) {
      params.push_back(dtypeToLLVM(arg.dtype()));
    } else {
      params.push_back(dtypeToLLVMPtr(arg.dtype()));
    }
    varToArg_[arg.var()] = i;
  }
  llvm::FunctionType* fntype = llvm::FunctionType::get(retTy, params, false);
  fn_ = llvm::Function::Create(
      fntype, llvm::Function::PrivateLinkage, "pytorch", module_.get());
  fn_->addFnAttr(llvm::Attribute::AlwaysInline);
  for (const auto i : c10::irange(args.size())) {
    if (!args[i].isVar()) {
      fn_->addParamAttr(i, llvm::Attribute::NoAlias);
    }
  }

  emitWrapper(params);
  emitKernel(stmt, params);

  jit_->addModule(std::move(module_), std::move(context_));
  if (!LLVMAOTWorkflow()) {
    auto sym = jit_->findSymbol(kernel_func_name_);
    kernelAddress_ = assertSuccess(sym.getAddress());
  }
}

llvm::LLVMContext& LLVMCodeGenImpl::getContext() {
  return *context_;
}

llvm::Type* LLVMCodeGenImpl::dtypeToLLVM(Dtype dtype) {
  switch (dtype.scalar_type()) {
#define TYPE_CASE(_1, n) \
  case ScalarType::n:    \
    return n##Ty_;       \
    break;

    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      return CharTy_;
      break;

    case ScalarType::QUInt8:
      return ByteTy_;
      break;

    case ScalarType::BFloat16:
      return ShortTy_;
      break;

    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

llvm::Type* LLVMCodeGenImpl::dtypeToLLVMPtr(Dtype dtype) {
  return dtypeToLLVM(dtype)->getPointerTo();
}

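// Emit the external entry point with signature `int <kernel_func_name>(void** args)`.
// Each args[i] is loaded and cast to the i-th typed parameter of the private
// kernel function, which is then called directly.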
void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
#if LLVM_VERSION_MAJOR >= 15
  auto wrapper = llvm::Function::Create(
      llvm::FunctionType::get(IntTy_, {OpqPtrTy_}, false),
      llvm::Function::ExternalLinkage,
      kernel_func_name_,
      module_.get());
#else
  auto voidPtrTy = llvm::Type::getInt8PtrTy(getContext());
  auto voidPtrPtrTy = voidPtrTy->getPointerTo();
  auto wrapper = llvm::Function::Create(
      llvm::FunctionType::get(IntTy_, {voidPtrPtrTy}, false),
      llvm::Function::ExternalLinkage,
      kernel_func_name_,
      module_.get());
#endif

  {
    // Work around UBSAN crashes that read 8 bytes in front of every function.
    // Otherwise, if the function was placed at the beginning of a page, reading
    // 8B before the page could trigger a wild-addr-read ASAN failure if the
    // page before this function was not mapped.
    // - https://reviews.llvm.org/D148665
    // - https://github.com/llvm/llvm-project/issues/65253
    // Place the variable just before the function; otherwise the optimizer
    // might disable this workaround.
    // https://llvm.org/docs/LangRef.html#prefix-data
    wrapper->setPrefixData(llvm::Constant::getNullValue(
        llvm::ArrayType::get(llvm::Type::getInt8Ty(getContext()), 8)));
  }

  auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
  irb_.SetInsertPoint(wrapBB);
  llvm::SmallVector<llvm::Value*, 6> wrappedArgs;
  for (const auto i : c10::irange(params.size())) {
#if LLVM_VERSION_MAJOR >= 15
    auto argp = irb_.CreateGEP(
        OpqPtrTy_,
        wrapper->arg_begin(),
        llvm::ConstantInt::getSigned(IntTy_, i));
    if (params[i]->isPointerTy()) {
      auto arg =
          irb_.CreatePointerCast(irb_.CreateLoad(OpqPtrTy_, argp), params[i]);
      wrappedArgs.push_back(arg);
    } else {
      auto p =
          irb_.CreatePointerCast(irb_.CreateLoad(OpqPtrTy_, argp), OpqPtrTy_);
      auto arg = irb_.CreateLoad(params[i], p);
      wrappedArgs.push_back(arg);
    }
#else
    auto argp = irb_.CreateGEP(
        voidPtrTy,
        wrapper->arg_begin(),
        llvm::ConstantInt::getSigned(IntTy_, i));
    if (params[i]->isPointerTy()) {
      auto arg = irb_.CreatePointerCast(
          irb_.CreateLoad(argp->getType()->getPointerElementType(), argp),
          params[i]);
      wrappedArgs.push_back(arg);
    } else {
      auto p = irb_.CreatePointerCast(
          irb_.CreateLoad(argp->getType()->getPointerElementType(), argp),
          params[i]->getPointerTo());
      auto arg = irb_.CreateLoad(p->getType()->getPointerElementType(), p);
      wrappedArgs.push_back(arg);
    }
#endif
  }
  auto cc = irb_.CreateCall(fn_, wrappedArgs);
  irb_.CreateRet(cc);
}

class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
 private:
  ExprPtr mutate(const IntrinsicsPtr& v) override {
    if (v->op_type() == kTanh) {
      ScalarType stype = v->dtype().scalar_type();
      if (stype == ScalarType::Float) {
        return fast_tanh(ExprHandle(v->param(0)->accept_mutator(this))).node();
      }
    } else if (v->op_type() == kSigmoid) {
      ScalarType stype = v->dtype().scalar_type();
      if (stype == ScalarType::Float) {
        return fast_sigmoid(ExprHandle(v->param(0)->accept_mutator(this)))
            .node();
      }
    }
    // TODO: fast exp
    // TODO: fast erf
    return GenericIntrinsicsExpander::mutate(v);
  }
};

void LLVMCodeGenImpl::emitKernel(
    StmtPtr stmt,
    const std::vector<llvm::Type*>& params) {
  // Set insert point to the real function.
  bb_ = llvm::BasicBlock::Create(getContext(), "entry", fn_);
  irb_.SetInsertPoint(bb_);

  // Maybe expand some of the intrinsics.
  if (FLAGS_torch_jit_llvm_use_fast_intrinsics) {
    LLVMIntrinsicsExpander intrinsics_expander;
    stmt = stmt->accept_mutator(&intrinsics_expander);
  } else {
    GenericIntrinsicsExpander intrinsics_expander;
    stmt = stmt->accept_mutator(&intrinsics_expander);
  }

  // Compile the kernel.
  stmt->accept(this);

  // If the kernel is empty, set a default return value.
  if (value_ == nullptr) {
    value_ = llvm::ConstantInt::get(IntTy_, 0);
  }

  irb_.CreateRet(value_);

  // Print graph debug info before optimization.
  llvm::SmallVector<char, 0> asmBuffer;
  llvm::raw_svector_ostream asmStream(asmBuffer);
  if (GRAPH_DEBUG_ENABLED) {
    module_->print(asmStream, nullptr);
  }
  GRAPH_DEBUG(
      "\nLLVM module before optimizations\n\n", asmStream.str().str(), "\n");

  if (llvm::verifyFunction(*fn_, &llvm::outs())) {
    throw std::runtime_error("Function verification failed");
  }

  optimize(*module_);

  asmBuffer.clear();
  module_->print(asmStream, nullptr);
  llvmCode_ = asmStream.str().str();
  GRAPH_DEBUG(
      "\nLLVM module after optimizations\n\n", asmStream.str().str(), "\n");

  // Emit the target assembly for debug output.
  asmBuffer.clear();
  llvm::legacy::PassManager PM;
  jit_->getTargetMachine().addPassesToEmitFile(
      PM,
      asmStream,
      nullptr,
#if LLVM_VERSION_MAJOR >= 18
      llvm::CodeGenFileType::AssemblyFile);
#elif LLVM_VERSION_MAJOR >= 10
      llvm::CodeGenFileType::CGFT_AssemblyFile);
#else
      llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#endif
  PM.run(*module_);
  asmCode_ = asmStream.str().str();

  GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n");
}

// TODO: The binary ops are copypasta.

void LLVMCodeGenImpl::visit(const AddPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFAdd(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateAdd(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Add", v);
  }
}

void LLVMCodeGenImpl::visit(const SubPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFSub(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateSub(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Sub", v);
  }
}

void LLVMCodeGenImpl::visit(const MulPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFMul(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateMul(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Mul", v);
  }
}

void LLVMCodeGenImpl::visit(const DivPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFDiv(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateSDiv(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Div", v);
  }
}

void LLVMCodeGenImpl::visit(const AndPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateAnd(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in And", v);
  }
}

void LLVMCodeGenImpl::visit(const OrPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateOr(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Or", v);
  }
}

void LLVMCodeGenImpl::visit(const XorPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateXor(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Xor", v);
  }
}

void LLVMCodeGenImpl::visit(const LshiftPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateShl(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Lshift", v);
  }
}

void LLVMCodeGenImpl::visit(const RshiftPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    if (v->lhs()->dtype().is_signed()) {
      value_ = irb_.CreateAShr(lhs, rhs);
    } else {
      value_ = irb_.CreateLShr(lhs, rhs);
    }
  } else {
    throw malformed_input("llvm_codegen: bad type in Rshift", v);
  }
}

void LLVMCodeGenImpl::visit(const ModPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateSRem(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Mod", v);
  }
}

void LLVMCodeGenImpl::visit(const MaxPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  v->rhs()->accept(this);
  auto rhs = this->value_;

  if (v->dtype().is_integral()) {
    auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSGT(lhs, rhs)
                                       : irb_.CreateICmpUGT(lhs, rhs);
    value_ = irb_.CreateSelect(icmp, lhs, rhs);
    return;
  }

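  // Floating point: FCMP_UNO of lhs against 0.0 is true iff lhs is NaN, in
  // which case lhs (the NaN) is propagated; otherwise take the larger value.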
  value_ = irb_.CreateSelect(
      irb_.CreateFCmp(
          llvm::FCmpInst::FCMP_UNO,
          lhs,
          llvm::ConstantFP::get(lhs->getType(), 0.0)),
      lhs,
      irb_.CreateSelect(
          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const MinPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  v->rhs()->accept(this);
  auto rhs = this->value_;
  if (v->dtype().is_integral()) {
    auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSLT(lhs, rhs)
                                       : irb_.CreateICmpULT(lhs, rhs);
    value_ = irb_.CreateSelect(icmp, lhs, rhs);
    return;
  }

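  // Floating point: as in Max above, propagate a NaN lhs, otherwise take the
  // smaller value.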
  value_ = irb_.CreateSelect(
      irb_.CreateFCmp(
          llvm::FCmpInst::FCMP_UNO,
          lhs,
          llvm::ConstantFP::get(lhs->getType(), 0.0)),
      lhs,
      irb_.CreateSelect(
          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const CompareSelectPtr& v) {
  auto genUnbiased = [this, v]() -> llvm::Value* {
    v->lhs()->accept(this);
    auto lhs = this->value_;
    v->rhs()->accept(this);
    auto rhs = this->value_;
    v->ret_val1()->accept(this);
    auto retval1 = this->value_;
    v->ret_val2()->accept(this);
    auto retval2 = this->value_;

    auto type_used = v->lhs()->dtype().scalar_type();

    llvm::Value* cmp_;
    CompareSelectOperation cmp_op_ = v->compare_select_op();

    if (c10::isIntegralType(type_used, true)) {
      cmp_ = irb_.CreateICmp(
          llvm_comparison_predicate(cmp_op_, type_used), lhs, rhs);
    } else if (c10::isFloatingType(type_used)) {
      cmp_ = irb_.CreateFCmp(llvm_fp_comparison_predicate(cmp_op_), lhs, rhs);
    } else {
      throw std::runtime_error("invalid type for CompareSelect");
    }

    return irb_.CreateSelect(cmp_, retval1, retval2);
  };

  auto genBiased = [this, v]() -> llvm::Value* {
    v->lhs()->accept(this);
    auto lhs = this->value_;
    v->rhs()->accept(this);
    auto rhs = this->value_;

    auto cmp_type = v->lhs()->dtype().scalar_type();
    auto cmp_op = v->compare_select_op();
    llvm::Value* cmp;

    if (c10::isIntegralType(cmp_type, true)) {
      cmp = irb_.CreateICmp(
          llvm_comparison_predicate(cmp_op, cmp_type), lhs, rhs);
    } else if (c10::isFloatingType(cmp_type)) {
      cmp = irb_.CreateFCmp(llvm_fp_comparison_predicate(cmp_op), lhs, rhs);
    } else {
      throw std::runtime_error("invalid type for CompareSelect");
    }

    auto lanes = v->lhs()->dtype().lanes();
    if (lanes > 1) {
      auto maskType = llvm::Type::getIntNTy(getContext(), lanes);
      auto zero = llvm::ConstantInt::get(maskType, 0);
      auto mask = irb_.CreateBitOrPointerCast(cmp, maskType);
      cmp = irb_.CreateICmpNE(mask, zero);
    }

    auto then_block = llvm::BasicBlock::Create(getContext(), "then", fn_);
    auto else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
    auto end_block = llvm::BasicBlock::Create(getContext(), "block", fn_);
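    // Encode the bias as branch-weight metadata so LLVM lays out the likely
    // path as the fall-through.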
    constexpr int32_t total_weight = 100000;
    auto true_weight = v->bias() == kLikely ? total_weight : 0;
    auto false_weight = total_weight - true_weight;
    irb_.CreateCondBr(
        cmp,
        then_block,
        else_block,
        llvm::MDBuilder(getContext())
            .createBranchWeights(true_weight, false_weight));

    irb_.SetInsertPoint(then_block);
    v->ret_val1()->accept(this);
    llvm::Value* then_val = value_;
    then_block = irb_.GetInsertBlock();
    irb_.CreateBr(end_block);

    irb_.SetInsertPoint(else_block);
    v->ret_val2()->accept(this);
    llvm::Value* else_val = value_;
    else_block = irb_.GetInsertBlock();
    irb_.CreateBr(end_block);

    irb_.SetInsertPoint(end_block);
    llvm::PHINode* phi = irb_.CreatePHI(then_val->getType(), 2);
    phi->addIncoming(then_val, then_block);
    phi->addIncoming(else_val, else_block);
    return phi;
  };

  value_ = v->bias() == kUnbiased ? genUnbiased() : genBiased();
}

template <typename T>
typename std::enable_if<std::is_integral<T>::value, llvm::Value*>::type
getFromType(llvm::Type* type, T value) {
  return llvm::ConstantInt::get(type, value, std::is_signed<T>::value);
}

template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, llvm::Value*>::type
getFromType(llvm::Type* type, T value) {
  return llvm::ConstantFP::get(type, value);
}

#define IMM_VISIT_DECLARE(Type, Name)                  \
  void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \
    value_ = getFromType<Type>(Name##Ty_, v->value()); \
  }
AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE

void LLVMCodeGenImpl::visit(const HalfImmPtr& v) {
  value_ = llvm::ConstantFP::get(HalfTy_, v->value());
}

void LLVMCodeGenImpl::visit(const BFloat16ImmPtr& v) {
  value_ = llvm::ConstantInt::get(ShortTy_, v->value().x);
}

void LLVMCodeGenImpl::visit(const BoolImmPtr& v) {
  value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}

static llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
  if (lanes > 1) {
    return llvm::VectorType::get(type, ElementCount(lanes));
  } else {
    return type;
  }
}

void LLVMCodeGenImpl::visit(const CastPtr& v) {
  v->src_value()->accept(this);

  auto dst_type = v->dtype().scalar_type();
  auto src_type = v->src_value()->dtype().scalar_type();
  bool is_to_bf16 = (dst_type == c10::kBFloat16);
  bool is_to_float = (dst_type == c10::kFloat);
  bool is_from_bf16 = (src_type == c10::kBFloat16);
  bool is_from_float = (src_type == c10::kFloat);

  bool cast_from_bf16_to_fp32 = is_from_bf16 && is_to_float;
  bool cast_from_fp32_to_bf16 = is_from_float && is_to_bf16;
  bool non_bf16_cast = (!is_to_bf16) && (!is_from_bf16);
  bool valid_bf16_cast = cast_from_bf16_to_fp32 || cast_from_fp32_to_bf16;
  TORCH_CHECK(
      valid_bf16_cast || non_bf16_cast,
      "Cast is not implemented for the conversion between ",
      src_type,
      " and ",
      dst_type,
      ".");

  llvm::Type* dstType =
      llvmTypeToVec(dtypeToLLVM(v->dtype()), v->dtype().lanes());
  llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype());

  if (srcType == dstType) {
    // do nothing.
    return;
  }

  bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte ||
      v->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->dtype().scalar_type() == ScalarType::Bool;
  bool srcUnsigned =
      v->src_value()->dtype().scalar_type() == ScalarType::Byte ||
      v->src_value()->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->src_value()->dtype().scalar_type() == ScalarType::Bool;

  // Scalar casts
  if (is_from_bf16) {
    // Shift the BF16 value left by 16 bits and then bit-cast the shifted
    // value to FP32.
    //   FP32_VAL = BF16_VAL << 16
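    // e.g. bf16 0x3F80 (1.0f) << 16 gives 0x3F800000, the fp32 bit pattern
    // of 1.0f.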
    auto lans = v->dtype().lanes();
    value_ = irb_.CreateZExt(value_, llvmTypeToVec(IntTy_, lans));
    auto vec_shl_val = toVec(llvm::ConstantInt::get(IntTy_, 16), lans);
    value_ = irb_.CreateShl(value_, vec_shl_val);
    value_ = irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(FloatTy_, lans));
    return;
  }

  if (is_to_bf16) {
    // Convert the FP32 value with RNE (round to nearest even). The algorithm
    // is as follows:
    //   STEP1: U32_VAL = BITCAST(F32_VAL)
    //   STEP2: U32_VAL_TMP = U32_VAL >> 16
    //   STEP3: U32_VAL_TMP = U32_VAL_TMP & 1
    //   STEP4: ROUNDING_BIAS = U32_VAL_TMP + UINT32(0x7FFF)
    //   STEP5: U32_VAL_TMP = U32_VAL + ROUNDING_BIAS
    //   STEP6: BF16_VAL = static_cast<UINT16>(U32_VAL_TMP >> 16)
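    // Worked example: F32_VAL = 1.0f = 0x3F800000; STEP2 yields 0x3F80,
    // STEP3 yields 0, so ROUNDING_BIAS = 0x7FFF; 0x3F800000 + 0x7FFF =
    // 0x3F807FFF, and STEP6 yields BF16_VAL = 0x3F80, i.e. bf16 1.0f.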
    auto lans = v->src_value()->dtype().lanes();
    auto shift_len = llvm::ConstantInt::get(IntTy_, 16);
    auto one = llvm::ConstantInt::get(ShortTy_, 1);
    auto rounding_bias = llvm::ConstantInt::get(ShortTy_, 0x7FFF);
    auto bf16_nan = llvm::ConstantInt::get(ShortTy_, 0xFFFF);

    auto mask = irb_.CreateFCmpOEQ(value_, value_);
    // STEP1: U32_VAL = BITCAST(F32_VAL)
    auto fp32_i32_value =
        irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(IntTy_, lans));
    // STEP2: U32_VAL_TMP = (U32_VAL >> 16)
    value_ = irb_.CreateLShr(fp32_i32_value, toVec(shift_len, lans));
    value_ = irb_.CreateTrunc(value_, llvmTypeToVec(ShortTy_, lans));
    // STEP3: U32_VAL_TMP = U32_VAL_TMP & 1
    value_ = irb_.CreateAnd(value_, toVec(one, lans));
    // STEP4: ROUNDING_BIAS = U32_VAL_TMP + UINT32(0x7FFF)
    value_ = irb_.CreateAdd(value_, toVec(rounding_bias, lans));
    value_ = irb_.CreateZExt(value_, llvmTypeToVec(IntTy_, lans));
    // STEP5: U32_VAL_TMP = U32_VAL + ROUNDING_BIAS
    value_ = irb_.CreateAdd(value_, fp32_i32_value);
    // STEP6: BF16_VAL = static_cast<UINT16>(U32_VAL_TMP >> 16)
    value_ = irb_.CreateLShr(value_, toVec(shift_len, lans));
    value_ = irb_.CreateTrunc(value_, llvmTypeToVec(ShortTy_, lans));
    value_ = irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(ShortTy_, lans));
    // If the value is NaN, return BF16 NaN.
    value_ = irb_.CreateSelect(mask, value_, toVec(bf16_nan, lans));
    return;
  }

  if (srcType->isFPOrFPVectorTy()) {
    if (dstType->isFPOrFPVectorTy()) {
      // As with eager, convert from Double -> Half by converting to Float
      // first and then to Half. TODO: __truncdfhf2
      if (v->dtype().scalar_type() == ScalarType::Half &&
          v->src_value()->dtype().scalar_type() == ScalarType::Double) {
        value_ = irb_.CreateFPCast(
            value_, llvmTypeToVec(FloatTy_, v->dtype().lanes()));
      }
      value_ = irb_.CreateFPCast(value_, dstType);
    } else if (dstType->isIntOrIntVectorTy()) {
      // Casting from Float -> i8 directly doesn't give correct results for
      // bool; instead set the result to one iff the input float is nonzero.
      if (v->dtype().scalar_type() == ScalarType::Bool) {
        llvm::Value* zero =
            toVec(llvm::ConstantFP::get(srcType, 0.), v->dtype().lanes());
        value_ = irb_.CreateFCmp(llvm::FCmpInst::FCMP_UNE, value_, zero);
        value_ = irb_.CreateICmpEQ(
            value_, llvm::ConstantInt::get(value_->getType(), 1));
        value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned);
        return;
      }

      if (destUnsigned) {
        value_ = irb_.CreateFPToUI(value_, dstType);
      } else {
        value_ = irb_.CreateFPToSI(value_, dstType);
      }
    } else {
      throw unimplemented_lowering(v);
    }
    return;
  }

  if (!srcType->isIntOrIntVectorTy()) {
    throw unimplemented_lowering(v);
  }
  if (dstType->isFPOrFPVectorTy()) {
    if (srcUnsigned) {
      value_ = irb_.CreateUIToFP(value_, dstType);
    } else {
      value_ = irb_.CreateSIToFP(value_, dstType);
    }
  } else if (dstType->isIntOrIntVectorTy()) {
    // Ensure the bool true value is exactly one, since we convert from bool
    // to int by zero-extending the int8.
    if (v->dtype().scalar_type() == ScalarType::Bool) {
      llvm::Value* zero =
          toVec(llvm::ConstantInt::get(srcType, 0), v->dtype().lanes());
      value_ = irb_.CreateICmpNE(value_, zero);
    }
    value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned);
  } else {
    throw unimplemented_lowering(v);
  }
}

void LLVMCodeGenImpl::visit(const BitCastPtr& v) {
  v->src_value()->accept(this);

  llvm::Type* dstType = dtypeToLLVM(v->dtype());
  if (v->dtype().lanes() > 1) {
    dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes()));
  }
  llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype());

  if (srcType == dstType) {
    // do nothing.
    return;
  }

  TORCH_CHECK(llvm::CastInst::isBitCastable(
      srcType->getScalarType(), dstType->getScalarType()));
  value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}

void LLVMCodeGenImpl::visit(const VarPtr& v) {
  value_ = varToValue(v);
}

llvm::Value* LLVMCodeGenImpl::varToValue(VarPtr v) {
  // It is possible for v to be in both varToVal_ and varToArg_.
  // In that case, varToVal_ takes precedence.
  if (varToVal_.count(v)) {
    return varToVal_.at(v);
  } else if (varToArg_.count(v)) {
    auto idx = varToArg_.at(v);
    auto arg = fn_->arg_begin() + idx;
    return arg;
  }
  return nullptr;
}

void LLVMCodeGenImpl::replaceVarMapping(
    const std::vector<VarPtr>& vars,
    const std::vector<llvm::Value*>& vals) {
  TORCH_CHECK(vars.size() == vals.size());
  for (const auto i : c10::irange(vars.size())) {
    VarPtr var = vars[i];
    llvm::Value* val = vals[i];
    if (val) {
      varToVal_[var] = val;
    } else {
      varToVal_.erase(var);
    }
  }
}

void LLVMCodeGenImpl::visit(const RampPtr& v) {
  v->base()->accept(this);
  auto base = this->value_;
  v->stride()->accept(this);
  auto stride = this->value_;
  int lanes = v->lanes();

  if (llvm::ConstantInt* const_stride =
          llvm::dyn_cast<llvm::ConstantInt>(stride)) {
    std::vector<llvm::Constant*> vals = {
        llvm::ConstantInt::get(base->getType(), 0)};
    for (int i = 1; i < lanes; ++i) {
      vals.push_back(llvm::ConstantExpr::getAdd(vals.back(), const_stride));
    }

    llvm::Value* offsets = llvm::ConstantVector::get(vals);
    llvm::Value* splat = irb_.CreateVectorSplat(lanes, base);
    value_ = irb_.CreateAdd(splat, offsets);
    return;
  }

  llvm::Type* vecType = nullptr;
  auto element_count = ElementCount(lanes);
  switch (v->dtype().scalar_type()) {
#define TYPE_CASE(_1, Name)                                    \
  case ScalarType::Name:                                       \
    vecType = llvm::VectorType::get(Name##Ty_, element_count); \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      vecType = llvm::VectorType::get(CharTy_, element_count);
      break;
    case ScalarType::QUInt8:
      vecType = llvm::VectorType::get(ByteTy_, element_count);
      break;
    case ScalarType::BFloat16:
      vecType = llvm::VectorType::get(ShortTy_, element_count);
      break;
    default:
      throw std::runtime_error("invalid dtype in Ramp");
  }

  value_ = llvm::UndefValue::get(vecType);
  for (int i = 0; i < lanes; ++i) {
    value_ = irb_.CreateInsertElement(value_, base, i);
    base = irb_.CreateAdd(base, stride);
  }
}

llvm::Value* LLVMCodeGenImpl::emitUnmaskedLoad(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx) {
#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
  return irb_.CreateLoad(ty, addr);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
  return irb_.CreateLoad(addr->getType()->getPointerElementType(), addr);
#endif
}

llvm::Value* LLVMCodeGenImpl::emitMaskedLoad(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx,
    llvm::Value* mask) {
  // Create block structure for the masked load.
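  // The generated CFG is roughly:
  //   preheader: br (mask == 1), cond, tail
  //   cond:      %v = load ...; br tail
  //   tail:      %phi = phi [undef, preheader], [%v, cond]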
  auto preheader = irb_.GetInsertBlock();
  auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
  auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_);

  // Test the mask
  auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1));
  irb_.CreateCondBr(cond, condblock, tailblock);

  // Do the load
  irb_.SetInsertPoint(condblock);

#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
  auto load = irb_.CreateLoad(ty, addr);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
  auto load = irb_.CreateLoad(addr->getType()->getPointerElementType(), addr);
#endif

  irb_.CreateBr(tailblock);

  // Merge the masked and unmasked CFG edges
  irb_.SetInsertPoint(tailblock);
  auto phi = irb_.CreatePHI(load->getType(), 2);
  phi->addIncoming(llvm::UndefValue::get(load->getType()), preheader);
  phi->addIncoming(load, condblock);

  return phi;
}

void LLVMCodeGenImpl::visit(const LoadPtr& v) {
  if (v->dtype().lanes() == 1) {
    v->base_handle()->accept(this);
    auto base = this->value_;
    v->flat_index()->accept(this);
    auto idx = this->value_;
    value_ = emitUnmaskedLoad(dtypeToLLVM(v->dtype()), base, idx);
    return;
  }

  llvm::Type* loadType = nullptr;

  auto element_count = ElementCount(v->dtype().lanes());
  switch (v->dtype().scalar_type()) {
#define TYPE_CASE(_1, Name)                                     \
  case ScalarType::Name:                                        \
    loadType = llvm::VectorType::get(Name##Ty_, element_count); \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      loadType = llvm::VectorType::get(CharTy_, element_count);
      break;
    case ScalarType::QUInt8:
      loadType = llvm::VectorType::get(ByteTy_, element_count);
      break;
    case ScalarType::BFloat16:
      loadType = llvm::VectorType::get(ShortTy_, element_count);
      break;
    default:
      throw std::runtime_error("invalid dtype in Load");
  }

  // Handle the case where the load is contiguous and unmasked efficiently.
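  // e.g. a float load at Ramp(i, 1, 8) becomes a single <8 x float> load at
  // &base[i] instead of eight scalar loads.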
  auto idx_ramp = to<Ramp>(v->flat_index());
  if (idx_ramp) {
    auto stride_imm = intValue(idx_ramp->stride());
    if (stride_imm && *stride_imm == 1) {
      v->base_handle()->accept(this);
      auto base = this->value_;
      idx_ramp->base()->accept(this);
      auto first_idx = this->value_;

#if LLVM_VERSION_MAJOR >= 15
      auto addr = irb_.CreateGEP(dtypeToLLVM(v->dtype()), base, first_idx);
#else
      auto addr = irb_.CreateGEP(
          base->getType()->getScalarType()->getPointerElementType(),
          base,
          first_idx);
#endif

      auto vaddr = irb_.CreateBitOrPointerCast(
          addr, llvm::PointerType::get(loadType, 0));
#if LLVM_VERSION_MAJOR >= 12
      value_ = irb_.CreateAlignedLoad(loadType, vaddr, llvm::MaybeAlign(4));
#else
      value_ = irb_.CreateAlignedLoad(loadType, vaddr, 4);
#endif
      return;
    }
  }

  // Fall back to a scalar implementation
  v->base_handle()->accept(this);
  auto base = this->value_;
  v->flat_index()->accept(this);
  auto idx = this->value_;

  llvm::Value* load = llvm::UndefValue::get(loadType);
  for (int i = 0; i < v->dtype().lanes(); ++i) {
    auto sub_idx = irb_.CreateExtractElement(idx, i);
    llvm::Value* sub_load = nullptr;
    sub_load = emitUnmaskedLoad(dtypeToLLVM(v->dtype()), base, sub_idx);
    load = irb_.CreateInsertElement(load, sub_load, i);
  }

  value_ = load;
}

#if LLVM_VERSION_MAJOR >= 15
// Pack the arguments into an aggregate struct for forwarding.
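// For example, hypothetical args (i64 start, float* buf) are stored into a
// stack-allocated struct { i64, float* } whose address is passed along.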
packFuncArgs(const std::vector<llvm::Value * > & func_args)1496 TypedPointer LLVMCodeGenImpl::packFuncArgs(
1497     const std::vector<llvm::Value*>& func_args) {
1498   if (func_args.empty()) {
1499     llvm::PointerType* VoidPtrType = llvm::PointerType::getUnqual(getContext());
1500     return TypedPointer(
1501         VoidPtrType, llvm::ConstantPointerNull::get(VoidPtrType));
1502   }
1503   std::vector<llvm::Type*> arg_types(func_args.size());
1504   for (const auto i : c10::irange(func_args.size())) {
1505     arg_types[i] = func_args[i]->getType();
1506   }
1507   llvm::StructType* packed_type = llvm::StructType::create(arg_types);
1508   llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
1509   llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
1510   llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
1511   for (const auto i : c10::irange(func_args.size())) {
1512     llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
1513         packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
1514     irb_.CreateStore(func_args[i], dst_ptr);
1515   }
1516   return TypedPointer(packed_type, packed);
1517 }
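// For illustration (assumed example): packing the two arguments
// {i64 %start, float* %buf} creates an anonymous struct and fills it slot
// by slot, roughly:
//   %t      = type { i64, float* }
//   %packed = alloca %t, i32 1
//   %p0     = getelementptr inbounds %t, %t* %packed, i32 0, i32 0
//   store i64 %start, i64* %p0
//   %p1     = getelementptr inbounds %t, %t* %packed, i32 0, i32 1
//   store float* %buf, float** %p1
// TypedPointer carries the struct type next to the address because opaque
// pointers (LLVM >= 15) no longer encode a pointee type.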
1518 
1519 // Unpack the aggregate struct into individual arguments.
1520 std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
1521     TypedPointer packed,
1522     int arg_count) {
1523   // TODO: extract arg_count from packed.
1524   std::vector<llvm::Value*> func_args(arg_count);
1525   llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
1526   for (const auto i : c10::irange(arg_count)) {
1527     llvm::Type* field_type = packed.type->getStructElementType(i);
1528     llvm::Value* field_addr = irb_.CreateInBoundsGEP(
1529         packed.type, packed.addr, {zero, llvm::ConstantInt::get(IntTy_, i)});
1530     func_args[i] = irb_.CreateLoad(field_type, field_addr);
1531   }
1532   return func_args;
1533 }
1534 #else
1535 // Pack the arguments into an aggregate struct for forwarding.
1536 llvm::Value* LLVMCodeGenImpl::packFuncArgs(
1537     const std::vector<llvm::Value*>& func_args) {
1538   if (func_args.empty()) {
1539     llvm::PointerType* VoidPtrType = llvm::Type::getInt8PtrTy(getContext());
1540     llvm::Constant* NullPtr = llvm::ConstantPointerNull::get(VoidPtrType);
1541     return NullPtr;
1542   }
1543   std::vector<llvm::Type*> arg_types(func_args.size());
1544   for (const auto i : c10::irange(func_args.size())) {
1545     arg_types[i] = func_args[i]->getType();
1546   }
1547   llvm::StructType* packed_type = llvm::StructType::create(arg_types);
1548   llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
1549   llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
1550   llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
1551   for (const auto i : c10::irange(func_args.size())) {
1552     llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
1553         packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
1554     irb_.CreateStore(func_args[i], dst_ptr);
1555   }
1556   return packed;
1557 }
1558 
1559 // Unpack the aggregate struct into individual arguments.
1560 std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
1561     llvm::Value* packed,
1562     int arg_count) {
1563   // TODO: extract arg_count from packed.
1564   std::vector<llvm::Value*> func_args(arg_count);
1565   llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
1566   for (const auto i : c10::irange(arg_count)) {
1567     llvm::Type* packed_type = packed->getType()->getPointerElementType();
1568     llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
1569         packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
1570     func_args[i] =
1571         irb_.CreateLoad(dst_ptr->getType()->getPointerElementType(), dst_ptr);
1572   }
1573   return func_args;
1574 }
1575 #endif
1576 
1577 // Lower the parallel for-loop.
1578 // * Move the body into its own closure.
1579 // * Identify vars that cross the closure boundary and forward them as arguments.
1580 // * Send the closure and range to the dispatcher for execution.
1581 void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
1582   // Create "start" and "stop" values.
1583   v->start()->accept(this);
1584   auto start = this->value_;
1585   v->stop()->accept(this);
1586   auto stop = this->value_;
1587 
1588   // The Vars that need to be forwarded into the body closure.
1589   std::vector<VarPtr> body_arg_vars;
1590   // The corresponding Value*s in the caller that the old body referenced.
1591   std::vector<llvm::Value*> body_caller_vals;
1592   // Corresponding Value* that will be used in the new body closure.
1593   std::vector<llvm::Value*> body_closure_args;
1594 
1595   // Identify the VarPtrs used in the body but generated outside it.
1596   VarFinder var_finder;
1597   v->body()->accept(&var_finder);
1598   auto& vars = var_finder.vars();
1599   for (auto& var : vars) {
1600     if (llvm::Value* value = varToValue(var)) {
1601       body_arg_vars.push_back(var);
1602       body_caller_vals.push_back(value);
1603     }
1604   }
1605 
1606   // Pack the arguments into an automatic variable for forwarding.
1607 #if LLVM_VERSION_MAJOR >= 15
1608   TypedPointer packData = packFuncArgs(body_caller_vals);
1609   llvm::Value* packed_caller_args = packData.addr;
1610 #else
1611   llvm::Value* packed_caller_args = packFuncArgs(body_caller_vals);
1612 #endif
1613   // Remember where we are before moving to the new function.
1614   llvm::BasicBlock* old_insert_block = irb_.GetInsertBlock();
1615 
1616   // Create the new body closure code.
1617 #if LLVM_VERSION_MAJOR >= 15
1618   auto func_type =
1619       llvm::FunctionType::get(VoidTy_, {LongTy_, OpqPtrTy_}, false);
1620 #else
1621   auto func_type =
1622       llvm::FunctionType::get(VoidTy_, {LongTy_, Int8PtrTy_}, false);
1623 #endif
1624 
1625   llvm::Function* func = llvm::Function::Create(
1626       func_type, llvm::Function::PrivateLinkage, "func", module_.get());
1627   auto func_body = llvm::BasicBlock::Create(getContext(), "func_body", func);
1628   irb_.SetInsertPoint(func_body);
1629   auto args = func->arg_begin();
1630   llvm::Value* index = args++;
1631   llvm::Value* packed_func_args_raw = args++;
1632   llvm::Value* packed_func_args = irb_.CreatePointerCast(
1633       packed_func_args_raw, packed_caller_args->getType());
1634 
1635   // Cast the index if needed, then unpack the arguments from the opaque buffer.
1636   if (v->var()->dtype().scalar_type() != c10::kLong) {
1637     index = irb_.CreateIntCast(
1638         index, dtypeToLLVM(v->var()->dtype()), v->var()->dtype().is_signed());
1639   }
1640 #if LLVM_VERSION_MAJOR >= 15
1641   body_closure_args =
1642       unpackFuncArgs({packData.type, packed_func_args}, body_arg_vars.size());
1643 #else
1644   body_closure_args = unpackFuncArgs(packed_func_args, body_arg_vars.size());
1645 #endif
1646   // Set the codegen to the new func.
1647   // TODO: this should be replaced by RAII wrappers.
1648   varToVal_[v->var()] = index;
1649   replaceVarMapping(body_arg_vars, body_closure_args);
1650   llvm::Function* old_fn = fn_;
1651   fn_ = func;
1652   if (v->body()) {
1653     v->body()->accept(this);
1654   }
1655   // Restore back to the previous fn_
1656   fn_ = old_fn;
1657   irb_.CreateRet(nullptr);
1658   replaceVarMapping(body_arg_vars, body_caller_vals);
1659   varToVal_.erase(v->var());
1660 
1661   // Point back to the original block and generate the dispatcher call.
1662   irb_.SetInsertPoint(old_insert_block);
1663 
1664 #if LLVM_VERSION_MAJOR >= 15
1665   llvm::Value* packed_caller_args_ptr =
1666       irb_.CreatePointerCast(packed_caller_args, OpqPtrTy_);
1667   llvm::Value* func_value = irb_.CreatePointerCast(func, OpqPtrTy_);
1668   llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get(
1669       VoidTy_, {OpqPtrTy_, LongTy_, LongTy_, OpqPtrTy_}, false);
1670 #else
1671   llvm::Value* packed_caller_args_ptr =
1672       irb_.CreatePointerCast(packed_caller_args, Int8PtrTy_);
1673   llvm::Value* func_value = irb_.CreatePointerCast(func, Int8PtrTy_);
1674   llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get(
1675       VoidTy_, {Int8PtrTy_, LongTy_, LongTy_, Int8PtrTy_}, false);
1676 #endif
1677 
1678   FunctionCallee dispatcher_callee =
1679       module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
1680   llvm::Function* dispatcher =
1681       llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
1682   dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
1683   start = irb_.CreateIntCast(start, LongTy_, true);
1684   stop = irb_.CreateIntCast(stop, LongTy_, true);
1685   irb_.CreateCall(
1686       dispatcher, {func_value, start, stop, packed_caller_args_ptr});
1687   value_ = llvm::ConstantInt::get(IntTy_, 0);
1688 }
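// To summarize the shapes involved (hedged; inferred from the FunctionTypes
// above): the body closure is emitted with the signature
//   void func(int64_t index, void* packed_args);
// and the loop itself reduces to a single runtime call, roughly
//   DispatchParallel((void*)func, start, stop, (void*)packed_caller_args);
// where DispatchParallel is expected to be resolved by the NNC runtime and
// to invoke the closure once for each index in [start, stop).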
1689 
1690 void LLVMCodeGenImpl::visit(const ForPtr& v) {
1691   if (v->is_parallel()) {
1692     processParallelFor(v);
1693     return;
1694   }
1695 
1696   // Create "start" and "stop" values.
1697   v->start()->accept(this);
1698   auto start = this->value_;
1699   v->stop()->accept(this);
1700   auto stop = this->value_;
1701 
1702   // Create block for loop condition test.
1703   auto preheader = irb_.GetInsertBlock();
1704   auto condBlock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
1705   irb_.CreateBr(condBlock);
1706   irb_.SetInsertPoint(condBlock);
1707 
1708   // Set up phi node for index variable.
1709   auto idx = irb_.CreatePHI(start->getType(), 2);
1710   idx->addIncoming(start, preheader);
1711   if (!varToVal_.count(v->var())) {
1712     varToVal_.emplace(v->var(), idx);
1713   } else {
1714     throw std::runtime_error("var should not exist before");
1715   }
1716 
1717   // Create the body and exit blocks.
1718   auto body = llvm::BasicBlock::Create(getContext(), "body", fn_);
1719   auto exit = llvm::BasicBlock::Create(getContext(), "exit", fn_);
1720 
1721   // Create the stop condition.
1722   auto cond = irb_.CreateICmpSLT(idx, stop);
1723   irb_.CreateCondBr(cond, body, exit);
1724 
1725   // Codegen the body.
1726   irb_.SetInsertPoint(body);
1727   if (v->body()) {
1728     v->body()->accept(this);
1729   }
1730   // "Body" block may have changed if we generated nested control flow.
1731   body = irb_.GetInsertBlock();
1732 
1733   // Increment the index variable and branch back to loop test.
1734   auto inc =
1735       irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(start->getType(), 1));
1736   irb_.CreateBr(condBlock);
1737   idx->addIncoming(inc, body);
1738 
1739   // Exit the loop.
1740   irb_.SetInsertPoint(exit);
1741 
1742   varToVal_.erase(v->var());
1743   value_ = llvm::ConstantInt::get(IntTy_, 0);
1744 }
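// Sketch of the emitted loop CFG (illustrative):
//   preheader: br %cond
//   cond:      %idx = phi [ %start, %preheader ], [ %inc, %body ]
//              br (icmp slt %idx, %stop), %body, %exit
//   body:      <loop body>; %inc = add %idx, 1; br %cond
//   exit:
// No vectorization happens here; contiguous vector loads/stores come from
// explicit Ramp indices, and any further vectorization is left to the
// LoopVectorize/SLP passes scheduled in optimize().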
1745 
1746 void LLVMCodeGenImpl::visit(const BlockPtr& v) {
1747   BlockPtr last = scope_;
1748   scope_ = v;
1749 
1750   for (StmtPtr s : *v) {
1751     s->accept(this);
1752   }
1753 
1754   scope_ = last;
1755 
1756   auto it = scopeToVar_.find(v);
1757   if (it != scopeToVar_.end()) {
1758     for (VarPtr e : it->second) {
1759       if (varToVal_.erase(e) != 1) {
1760         throw std::runtime_error("erasing var that doesn't exist");
1761       }
1762     }
1763   }
1764 }
1765 
1766 void LLVMCodeGenImpl::emitUnmaskedStore(
1767     llvm::Type* ty,
1768     llvm::Value* base,
1769     llvm::Value* idx,
1770     llvm::Value* val) {
1771 #if LLVM_VERSION_MAJOR >= 15
1772   auto addr = irb_.CreateGEP(ty, base, idx);
1773 #else
1774   auto addr = irb_.CreateGEP(
1775       base->getType()->getScalarType()->getPointerElementType(), base, idx);
1776 #endif
1777 
1778   irb_.CreateStore(val, addr);
1779 }
1780 
1781 void LLVMCodeGenImpl::emitMaskedStore(
1782     llvm::Type* ty,
1783     llvm::Value* base,
1784     llvm::Value* idx,
1785     llvm::Value* mask,
1786     llvm::Value* val) {
1787   // Create block structure for the masked store.
1788   auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
1789   auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_);
1790 
1791   // Test the mask
1792   auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1));
1793   irb_.CreateCondBr(cond, condblock, tailblock);
1794 
1795   // Do the store
1796   irb_.SetInsertPoint(condblock);
1797 
1798 #if LLVM_VERSION_MAJOR >= 15
1799   auto addr = irb_.CreateGEP(ty, base, idx);
1800 #else
1801   auto addr = irb_.CreateGEP(
1802       base->getType()->getScalarType()->getPointerElementType(), base, idx);
1803 #endif
1804 
1805   irb_.CreateStore(val, addr);
1806   irb_.CreateBr(tailblock);
1807 
1808   // Merge the masked and unmasked CFG edges
1809   irb_.SetInsertPoint(tailblock);
1810 }
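// Illustrative shape of the masked store: a half-diamond whose false edge
// skips the store entirely,
//   br (mask == 1), %cond, %tail
//   cond: store %val, %addr; br %tail
//   tail: ...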
1811 
1812 void LLVMCodeGenImpl::visit(const StorePtr& v) {
1813   if (v->value()->dtype().lanes() == 1) {
1814     v->base_handle()->accept(this);
1815     auto base = this->value_;
1816     v->flat_index()->accept(this);
1817     auto idx = this->value_;
1818     v->value()->accept(this);
1819     auto val = this->value_;
1820 
1821     emitUnmaskedStore(dtypeToLLVM(v->value()->dtype()), base, idx, val);
1822     value_ = llvm::ConstantInt::get(IntTy_, 0);
1823     return;
1824   }
1825 
1826   v->base_handle()->accept(this);
1827   auto base = this->value_;
1828   v->value()->accept(this);
1829   auto val = this->value_;
1830 
1831   // Efficiently handle the case where the store is contiguous and unmasked
1832   auto idx_ramp = to<Ramp>(v->flat_index());
1833   if (idx_ramp) {
1834     auto stride_imm = intValue(idx_ramp->stride());
1835     if (stride_imm && *stride_imm == 1) {
1836       idx_ramp->base()->accept(this);
1837       auto first_idx = value_;
1838 
1839 #if LLVM_VERSION_MAJOR >= 15
1840       auto addr =
1841           irb_.CreateGEP(dtypeToLLVM(v->value()->dtype()), base, first_idx);
1842 #else
1843       auto addr = irb_.CreateGEP(
1844           base->getType()->getScalarType()->getPointerElementType(),
1845           base,
1846           first_idx);
1847 #endif
1848 
1849       auto vaddr = irb_.CreateBitOrPointerCast(
1850           addr, llvm::PointerType::get(val->getType(), 0));
1851 
1852 #if LLVM_VERSION_MAJOR >= 13
1853       irb_.CreateAlignedStore(val, vaddr, llvm::MaybeAlign(4));
1854 #else
1855       irb_.CreateAlignedStore(val, vaddr, 4);
1856 #endif
1857       value_ = llvm::ConstantInt::get(IntTy_, 0);
1858       return;
1859     }
1860   }
1861 
1862   v->flat_index()->accept(this);
1863   auto idx = this->value_;
1864 
1865   // Fall back to a scalar implementation
1866   for (int i = 0; i < v->value()->dtype().lanes(); ++i) {
1867     auto sub_idx = irb_.CreateExtractElement(idx, i);
1868     auto sub_val = irb_.CreateExtractElement(val, i);
1869     emitUnmaskedStore(dtypeToLLVM(v->value()->dtype()), base, sub_idx, sub_val);
1870   }
1871 
1872   value_ = llvm::ConstantInt::get(IntTy_, 0);
1873 }
1874 
1875 void LLVMCodeGenImpl::visit(const BroadcastPtr& v) {
1876   v->value()->accept(this);
1877   int lanes = v->lanes();
1878   value_ = irb_.CreateVectorSplat(lanes, value_);
1879 }
1880 
1881 void LLVMCodeGenImpl::visit(const IfThenElsePtr& v) {
1882   v->condition()->accept(this);
1883   llvm::Value* condition = value_;
1884   llvm::Value* c = irb_.CreateICmpNE(
1885       condition, llvm::ConstantInt::get(condition->getType(), 0));
1886 
1887   auto then_block = llvm::BasicBlock::Create(getContext(), "then", fn_);
1888   auto else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
1889   auto end_block = llvm::BasicBlock::Create(getContext(), "block", fn_);
1890   irb_.CreateCondBr(c, then_block, else_block);
1891 
1892   irb_.SetInsertPoint(then_block);
1893   v->true_value()->accept(this);
1894   llvm::Value* then_val = value_;
1895   then_block = irb_.GetInsertBlock();
1896   irb_.CreateBr(end_block);
1897 
1898   irb_.SetInsertPoint(else_block);
1899   v->false_value()->accept(this);
1900   llvm::Value* else_val = value_;
1901   else_block = irb_.GetInsertBlock();
1902   irb_.CreateBr(end_block);
1903 
1904   irb_.SetInsertPoint(end_block);
1905   llvm::PHINode* phi = irb_.CreatePHI(then_val->getType(), 2);
1906   phi->addIncoming(then_val, then_block);
1907   phi->addIncoming(else_val, else_block);
1908   value_ = phi;
1909 }
1910 
1911 static void applyMathFunctionAttributes(llvm::Function* f) {
1912   f->addFnAttr(llvm::Attribute::ReadNone);
1913   f->addFnAttr(llvm::Attribute::NoUnwind);
1914   // TODO: Adding this attr should be correct, but as of LLVM 9.0.1 adding it
1915   // causes some math functions to incorrectly be turned into tail calls.
1916   // f->addFnAttr(llvm::Attribute::Speculatable);
1917 #if LLVM_VERSION_MAJOR >= 9
1918   f->addFnAttr(llvm::Attribute::NoFree);
1919   f->addFnAttr(llvm::Attribute::WillReturn);
1920 #endif
1921 }
1922 
1923 llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) {
1924   if (lanes > 1) {
1925     return irb_.CreateVectorSplat(lanes, v);
1926   } else {
1927     return v;
1928   }
1929 }
1930 
1931 void LLVMCodeGenImpl::emitIsNan(IntrinsicsPtr v) {
1932   v->param(0)->accept(this);
1933   llvm::Type* dstType = dtypeToLLVM(v->dtype());
1934   if (!v->param(0)->dtype().is_floating_point()) {
1935     value_ = toVec(llvm::ConstantInt::get(dstType, 0), v->dtype().lanes());
1936   } else {
1937     TORCH_INTERNAL_ASSERT(
1938         v->dtype().scalar_type() == ScalarType::Int,
1939         buildErrorMessage(
1940             "Unexpected non-Int dtype of Intrinsics' result value in the fuser."));
1941     auto is_nan = irb_.CreateFCmpUNO(
1942         value_, llvm::ConstantFP::get(value_->getType(), 0.));
1943     if (v->dtype().lanes() > 1) {
1944       dstType =
1945           llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes()));
1946     }
1947     value_ = irb_.CreateIntCast(is_nan, dstType, /*isSigned*/ false);
1948   }
1949 }
1950 
1951 static bool wantSleef(const std::string& name) {
1952   // Using sleef on these ops is slower than libm.
1953   static std::unordered_set<std::string> noSleef = {
1954       "sqrt",
1955       "ceil",
1956       "trunc",
1957       "fabs",
1958       "floor",
1959       "sqrtf",
1960       "ceilf",
1961       "truncf",
1962       "fabsf",
1963       "floorf",
1964   };
1965   return noSleef.find(name) == noSleef.end();
1966 }
1967 
1968 LLVMCodeGenImpl::SimdCallee LLVMCodeGenImpl::getSimdFunction(
1969     const std::string& basename,
1970     llvm::Type* basetype,
1971     Arity arity,
1972     int lanes) {
1973   std::string name;
1974   llvm::Type* type;
1975   bool useSimd;
1976 
1977   // Determine whether to use vectorized intrinsic.
1978   auto const& featureString = jit_->getTargetMachine().getTargetFeatureString();
1979   bool hasAVX = featureString.find("+avx") != llvm::StringRef::npos;
1980   std::string typeSuffix = basetype == DoubleTy_ ? "d" : "";
1981   std::string sleefName =
1982       "Sleef_" + basename + typeSuffix + std::to_string(lanes);
1983   if (wantSleef(basename) && hasAVX && jit_->hasSymbol(sleefName)) {
1984     name = std::move(sleefName);
1985     type = llvm::VectorType::get(basetype, ElementCount(lanes));
1986     useSimd = true;
1987   } else {
1988     name = basename;
1989     type = basetype;
1990     useSimd = false;
1991   }
1992 
1993   // Get function to call from name and type.
1994   llvm::FunctionType* fntype;
1995   switch (arity) {
1996     case Unary:
1997       fntype = llvm::FunctionType::get(type, {type}, false);
1998       break;
1999     case Binary:
2000       fntype = llvm::FunctionType::get(type, {type, type}, false);
2001       break;
2002   }
2003   FunctionCallee callee = module_->getOrInsertFunction(name, fntype, {});
2004   applyMathFunctionAttributes(llvm::cast<llvm::Function>(callee.getCallee()));
2005   return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
2006 }
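// For illustration (assumed symbol names): with "+avx" in the target
// features, a float op with 8 lanes such as basename "expf" resolves to
// "Sleef_expf8" on <8 x float>, and a double op such as "exp" with 4 lanes
// resolves to "Sleef_expd4" on <4 x double>. If the Sleef symbol is not
// available (or the op is in the noSleef set), the scalar libm function is
// returned with useSimd=false, and the caller loops over lanes instead.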
2007 
2008 void LLVMCodeGenImpl::visit(const IntrinsicsPtr& v) {
2009   llvm::FunctionType* call_ty = nullptr;
2010   llvm::Value* call_fn = nullptr;
2011   bool call_simd_sleef = false;
2012 
2013   if (v->op_type() == kIsNan) {
2014     return emitIsNan(v);
2015   }
2016 
2017   if (v->dtype().scalar_type() == ScalarType::Float) {
2018     switch (v->op_type()) {
2019       case kRsqrt: {
2020         v->params().front()->accept(this);
2021         value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_);
2022         llvm::Value* constant =
2023             toVec(llvm::ConstantFP::get(FloatTy_, 1.0), v->dtype().lanes());
2024         value_ = irb_.CreateFDiv(constant, value_);
2025         return;
2026       } break;
2027 
2028 #define SIMD_UNARY_MATH_CASE(enum, name, type)                  \
2029   case enum: {                                                  \
2030     std::tie(call_ty, call_fn, call_simd_sleef) =               \
2031         getSimdFunction(name, type, Unary, v->dtype().lanes()); \
2032   } break;
2033         SIMD_UNARY_MATH_CASE(kLog10, "log10f", FloatTy_)
2034         SIMD_UNARY_MATH_CASE(kLog, "logf", FloatTy_)
2035         SIMD_UNARY_MATH_CASE(kLog1p, "log1pf", FloatTy_)
2036         SIMD_UNARY_MATH_CASE(kLog2, "log2f", FloatTy_)
2037         SIMD_UNARY_MATH_CASE(kExp, "expf", FloatTy_)
2038         SIMD_UNARY_MATH_CASE(kCos, "cosf", FloatTy_)
2039         SIMD_UNARY_MATH_CASE(kSin, "sinf", FloatTy_)
2040         SIMD_UNARY_MATH_CASE(kSqrt, "sqrtf", FloatTy_)
2041         SIMD_UNARY_MATH_CASE(kAbs, "fabsf", FloatTy_)
2042         SIMD_UNARY_MATH_CASE(kFloor, "floorf", FloatTy_)
2043         SIMD_UNARY_MATH_CASE(kCeil, "ceilf", FloatTy_)
2044         SIMD_UNARY_MATH_CASE(kTrunc, "truncf", FloatTy_)
2045         SIMD_UNARY_MATH_CASE(kRound, "nearbyint", FloatTy_)
2046         SIMD_UNARY_MATH_CASE(kErf, "erff", FloatTy_)
2047         SIMD_UNARY_MATH_CASE(kErfc, "erfcf", FloatTy_)
2048         SIMD_UNARY_MATH_CASE(kTan, "tanf", FloatTy_)
2049         SIMD_UNARY_MATH_CASE(kAcos, "acosf", FloatTy_)
2050         SIMD_UNARY_MATH_CASE(kAsin, "asinf", FloatTy_)
2051         SIMD_UNARY_MATH_CASE(kAtan, "atanf", FloatTy_)
2052         SIMD_UNARY_MATH_CASE(kCosh, "coshf", FloatTy_)
2053         SIMD_UNARY_MATH_CASE(kSinh, "sinhf", FloatTy_)
2054         SIMD_UNARY_MATH_CASE(kTanh, "tanhf", FloatTy_)
2055         SIMD_UNARY_MATH_CASE(kExpm1, "expm1f", FloatTy_)
2056         SIMD_UNARY_MATH_CASE(kLgamma, "lgammaf", FloatTy_)
2057 #undef SIMD_UNARY_MATH_CASE
2058 
2059 #define SIMD_BINARY_MATH_CASE(enum, name, type)                  \
2060   case enum: {                                                   \
2061     std::tie(call_ty, call_fn, call_simd_sleef) =                \
2062         getSimdFunction(name, type, Binary, v->dtype().lanes()); \
2063   } break;
2064         SIMD_BINARY_MATH_CASE(kAtan2, "atan2f", FloatTy_)
2065         SIMD_BINARY_MATH_CASE(kPow, "powf", FloatTy_)
2066         SIMD_BINARY_MATH_CASE(kFmod, "fmodf", FloatTy_)
2067 #undef SIMD_BINARY_MATH_CASE
2068 
2069       case kRemainder: {
2070         FunctionCallee callee = module_->getOrInsertFunction(
2071             "remainderf",
2072             llvm::FunctionType::get(FloatTy_, {FloatTy_, FloatTy_}, false),
2073             {});
2074         call_ty = callee.getFunctionType();
2075         call_fn = callee.getCallee();
2076         applyMathFunctionAttributes(llvm::cast<llvm::Function>(call_fn));
2077       } break;
2078 
2079       default: {
2080         throw unimplemented_lowering(v);
2081       } break;
2082     }
2083 
2084   } else if (v->dtype().scalar_type() == ScalarType::Double) {
2085     switch (v->op_type()) {
2086 #define SIMD_UNARY_MATH_CASE(enum, name, type)                  \
2087   case enum: {                                                  \
2088     std::tie(call_ty, call_fn, call_simd_sleef) =               \
2089         getSimdFunction(name, type, Unary, v->dtype().lanes()); \
2090   } break;
2091       SIMD_UNARY_MATH_CASE(kLog10, "log10", DoubleTy_)
2092       SIMD_UNARY_MATH_CASE(kLog, "log", DoubleTy_)
2093       SIMD_UNARY_MATH_CASE(kLog1p, "log1p", DoubleTy_)
2094       SIMD_UNARY_MATH_CASE(kLog2, "log2", DoubleTy_)
2095       SIMD_UNARY_MATH_CASE(kExp, "exp", DoubleTy_)
2096       SIMD_UNARY_MATH_CASE(kCos, "cos", DoubleTy_)
2097       SIMD_UNARY_MATH_CASE(kSin, "sin", DoubleTy_)
2098       SIMD_UNARY_MATH_CASE(kSqrt, "sqrt", DoubleTy_)
2099       SIMD_UNARY_MATH_CASE(kAbs, "fabs", DoubleTy_)
2100       SIMD_UNARY_MATH_CASE(kFloor, "floor", DoubleTy_)
2101       SIMD_UNARY_MATH_CASE(kCeil, "ceil", DoubleTy_)
2102       SIMD_UNARY_MATH_CASE(kTrunc, "trunc", DoubleTy_)
2103       SIMD_UNARY_MATH_CASE(kRound, "nearbyint", DoubleTy_)
2104       SIMD_UNARY_MATH_CASE(kErf, "erf", DoubleTy_)
2105       SIMD_UNARY_MATH_CASE(kErfc, "erfc", DoubleTy_)
2106       SIMD_UNARY_MATH_CASE(kTan, "tan", DoubleTy_)
2107       SIMD_UNARY_MATH_CASE(kAcos, "acos", DoubleTy_)
2108       SIMD_UNARY_MATH_CASE(kAsin, "asin", DoubleTy_)
2109       SIMD_UNARY_MATH_CASE(kAtan, "atan", DoubleTy_)
2110       SIMD_UNARY_MATH_CASE(kCosh, "cosh", DoubleTy_)
2111       SIMD_UNARY_MATH_CASE(kSinh, "sinh", DoubleTy_)
2112       SIMD_UNARY_MATH_CASE(kTanh, "tanh", DoubleTy_)
2113       SIMD_UNARY_MATH_CASE(kExpm1, "expm1", DoubleTy_)
2114       SIMD_UNARY_MATH_CASE(kLgamma, "lgamma", DoubleTy_)
2115 #undef SIMD_UNARY_MATH_CASE
2116 
2117       case kRsqrt: {
2118         v->params().front()->accept(this);
2119         value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_);
2120         llvm::Value* constant = llvm::ConstantFP::get(DoubleTy_, 1.0);
2121         if (v->dtype().lanes() > 1) {
2122           constant = irb_.CreateVectorSplat(v->dtype().lanes(), constant);
2123         }
2124         value_ = irb_.CreateFDiv(constant, value_);
2125         return;
2126       } break;
2127 
2128 #define SIMD_BINARY_MATH_CASE(enum, name, type)                  \
2129   case enum: {                                                   \
2130     std::tie(call_ty, call_fn, call_simd_sleef) =                \
2131         getSimdFunction(name, type, Binary, v->dtype().lanes()); \
2132   } break;
2133         SIMD_BINARY_MATH_CASE(kAtan2, "atan2", DoubleTy_)
2134         SIMD_BINARY_MATH_CASE(kPow, "pow", DoubleTy_)
2135         SIMD_BINARY_MATH_CASE(kFmod, "fmod", DoubleTy_)
2136 #undef SIMD_BINARY_MATH_CASE
2137 
2138       case kRemainder: {
2139         FunctionCallee callee = module_->getOrInsertFunction(
2140             "remainder",
2141             llvm::FunctionType::get(DoubleTy_, {DoubleTy_, DoubleTy_}, false),
2142             {});
2143         call_ty = callee.getFunctionType();
2144         call_fn = callee.getCallee();
2145         applyMathFunctionAttributes(llvm::cast<llvm::Function>(call_fn));
2146       } break;
2147 
2148       default: {
2149         throw unimplemented_lowering(v);
2150       } break;
2151     }
2152   } else if (v->dtype().is_integral() && v->op_type() == kAbs) {
2153     // abs is the only intrinsic defined for integer inputs in PyTorch eager
2154     v->params().front()->accept(this);
2155     if (!v->dtype().is_signed()) {
2156       return;
2157     }
2158     // TODO: use llvm.abs intrinsic for LLVM 12
2159     auto zero = llvm::ConstantInt::get(value_->getType(), 0);
2160     auto neg_value = irb_.CreateSub(zero, value_);
2161     auto icmp = irb_.CreateICmpSGT(value_, zero);
2162     value_ = irb_.CreateSelect(icmp, value_, neg_value);
2163     return;
2164   } else {
2165     TORCH_INTERNAL_ASSERT(
2166         false,
2167         buildErrorMessage(
2168             std::string("Unimplemented lowering for intrinsic '") +
2169             std::to_string(v->op_type()) + "' for input of dtype " +
2170             std::to_string(v->dtype().scalar_dtype()) +
2171             " in LLVM codegen of the fuser."));
2172   }
2173 
2174   std::vector<llvm::Value*> params;
2175   for (auto& p : v->params()) {
2176     p->accept(this);
2177     params.push_back(value_);
2178   }
2179 
2180   if (v->dtype().lanes() == 1 || call_simd_sleef) {
2181     value_ = irb_.CreateCall(call_ty, call_fn, params);
2182   } else {
2183     llvm::Type* vecType = params[0]->getType();
2184     value_ = llvm::UndefValue::get(vecType);
2185     for (int i = 0; i < v->dtype().lanes(); ++i) {
2186       std::vector<llvm::Value*> call_operands;
2187       for (auto p : params) {
2188         call_operands.push_back(irb_.CreateExtractElement(p, i));
2189       }
2190 
2191       llvm::Value* val = irb_.CreateCall(call_ty, call_fn, call_operands);
2192       value_ = irb_.CreateInsertElement(value_, val, i);
2193     }
2194   }
2195 }
2196 
2197 void LLVMCodeGenImpl::handleBufReuse(BufPtr buf, BufPtr buf_to_reuse) {
2198   llvm::Value* ptr = varToVal_.at(buf_to_reuse->base_handle());
2199   if (buf_to_reuse->dtype().scalar_type() != buf->dtype().scalar_type()) {
2200     ptr = irb_.CreatePointerCast(ptr, dtypeToLLVMPtr(buf->dtype()));
2201   }
2202   varToVal_[buf->base_handle()] = ptr;
2203 }
2204 
2205 void LLVMCodeGenImpl::visit(const ExternalCallPtr& v) {
2206   auto& func_registry = getNNCFunctionRegistry();
2207   if (!func_registry.count(v->func_name())) {
2208     throw unimplemented_lowering(v);
2209   }
2210 
2211   // Prepare a vector of bufs that we need to pass to the external function.
2212   // This vector is the output buf followed by the buf_args.
2213   std::vector<BufPtr> bufs(v->buf_args());
2214   bufs.insert(bufs.begin(), v->buf());
2215 
2216   int64_t bufs_num = bufs.size();
2217   int64_t args_num = v->args().size();
2218 
2219   // Count the size of the dims array - it consists of the dimensions of
2220   // all bufs concatenated together.
2221   int64_t dims_num = 0;
2222   for (BufPtr b : bufs) {
2223     dims_num += b->dims().size();
2224   }
2225 #if LLVM_VERSION_MAJOR >= 15
2226   llvm::Value* buf_ptrs = irb_.CreateAlloca(
2227       OpqPtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2228 #else
2229   llvm::Value* buf_ptrs = irb_.CreateAlloca(
2230       Int8PtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2231 #endif
2232   llvm::Value* buf_ranks = irb_.CreateAlloca(
2233       LongTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2234   llvm::Value* buf_dims = irb_.CreateAlloca(
2235       LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
2236   llvm::Value* buf_strides = irb_.CreateAlloca(
2237       LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
2238   llvm::Value* buf_dtypes = irb_.CreateAlloca(
2239       ByteTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2240   llvm::Value* extra_args = irb_.CreateAlloca(
2241       LongTy_, llvm::ConstantInt::getSigned(IntTy_, args_num));
2242 
2243   int i = 0;
2244   int dim_idx = 0;
2245   int stride_idx = 0;
2246   for (BufPtr b : bufs) {
2247     // Store value for buf pointer
2248     b->base_handle()->accept(this);
2249     auto buf_ptr = this->value_;
2250 #if LLVM_VERSION_MAJOR >= 15
2251     auto gep = irb_.CreateInBoundsGEP(
2252         OpqPtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2253     auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, OpqPtrTy_);
2254 #else
2255     auto gep = irb_.CreateInBoundsGEP(
2256         Int8PtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2257     auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, Int8PtrTy_);
2258 #endif
2259     irb_.CreateStore(buf_void_ptr, gep);
2260 
2261     // Store dtype of the buf
2262     gep = irb_.CreateInBoundsGEP(
2263         ByteTy_, buf_dtypes, llvm::ConstantInt::getSigned(IntTy_, i));
2264     irb_.CreateStore(
2265         llvm::ConstantInt::getSigned(ByteTy_, (int8_t)b->dtype().scalar_type()),
2266         gep);
2267 
2268     // Store rank of the buf
2269     gep = irb_.CreateInBoundsGEP(
2270         LongTy_, buf_ranks, llvm::ConstantInt::getSigned(IntTy_, i));
2271     irb_.CreateStore(
2272         llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);
2273 
2274     // Store dims of the buf
2275     for (const auto dim : c10::irange(b->dims().size())) {
2276       gep = irb_.CreateInBoundsGEP(
2277           LongTy_, buf_dims, llvm::ConstantInt::getSigned(IntTy_, dim_idx));
2278       b->dims()[dim]->accept(this);
2279       auto dim_val = this->value_;
2280       irb_.CreateStore(irb_.CreateZExt(dim_val, LongTy_), gep);
2281       dim_idx++;
2282     }
2283 
2284     // Store strides of the buf
2285     for (const auto dim : c10::irange(b->dims().size())) {
2286       gep = irb_.CreateInBoundsGEP(
2287           LongTy_,
2288           buf_strides,
2289           llvm::ConstantInt::getSigned(IntTy_, stride_idx));
2290       b->strides()[dim]->accept(this);
2291       auto stride_val = this->value_;
2292       irb_.CreateStore(irb_.CreateZExt(stride_val, LongTy_), gep);
2293       stride_idx++;
2294     }
2295 
2296     i++;
2297   }
2298 
2299   i = 0;
2300   for (ExprPtr arg : v->args()) {
2301     auto gep = irb_.CreateInBoundsGEP(
2302         LongTy_, extra_args, llvm::ConstantInt::getSigned(IntTy_, i));
2303     arg->accept(this);
2304     irb_.CreateStore(irb_.CreateZExtOrBitCast(this->value_, LongTy_), gep);
2305     i++;
2306   }
2307 
2308   // Generate the call itself
2309   std::string fname = v->func_name();
2310 #if LLVM_VERSION_MAJOR >= 15
2311   FunctionCallee callee = module_->getOrInsertFunction(
2312       fname,
2313       llvm::FunctionType::get(
2314           llvm::Type::getVoidTy(getContext()), // return type
2315           {LongTy_, // int64_t bufs_num
2316            OpqPtrTy_, // void** buf_data
2317            OpqPtrTy_, // int64_t* buf_ranks
2318            OpqPtrTy_, // int64_t* buf_dims
2319            OpqPtrTy_, // int64_t* buf_strides
2320            OpqPtrTy_, // int8_t* buf_dtypes
2321            LongTy_, // int64_t args_num
2322            OpqPtrTy_}, // int64_t* extra_args
2323           false)); // is var_arg
2324 #else
2325   FunctionCallee callee = module_->getOrInsertFunction(
2326       fname,
2327       llvm::FunctionType::get(
2328           llvm::Type::getVoidTy(getContext()), // return type
2329           {LongTy_, // int64_t bufs_num
2330            Int8PtrTy_->getPointerTo(), // void** buf_data
2331            LongTy_->getPointerTo(), // int64_t* buf_ranks
2332            LongTy_->getPointerTo(), // int64_t* buf_dims
2333            LongTy_->getPointerTo(), // int64_t* buf_strides
2334            ByteTy_->getPointerTo(), // int8_t* buf_dtypes
2335            LongTy_, // int64_t args_num
2336            LongTy_->getPointerTo()}, // int64_t* extra_args
2337           false)); // is var_arg
2338 #endif
2339 
2340   auto call_ty = callee.getFunctionType();
2341   auto call_fn = callee.getCallee();
2342   llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);
2343 
2344   irb_.CreateCall(
2345       call_ty,
2346       call_fn,
2347       {llvm::ConstantInt::getSigned(LongTy_, bufs_num),
2348        buf_ptrs,
2349        buf_ranks,
2350        buf_dims,
2351        buf_strides,
2352        buf_dtypes,
2353        llvm::ConstantInt::getSigned(LongTy_, args_num),
2354        extra_args});
2355 
2356   value_ = llvm::ConstantInt::get(IntTy_, 0);
2357 }
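// Hedged note: the FunctionType above implies the registered external
// function follows a C ABI along the lines of
//   void nnc_aten_op(int64_t bufs_num, void** buf_data, int64_t* buf_ranks,
//                    int64_t* buf_dims, int64_t* buf_strides,
//                    int8_t* buf_dtypes, int64_t args_num,
//                    int64_t* extra_args);
// where nnc_aten_op stands in for the looked-up func_name.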
2358 
2359 void LLVMCodeGenImpl::visit(const ExternalCallWithAllocPtr& v) {
2360   auto& func_registry = getNNCFunctionRegistry();
2361   if (!func_registry.count(v->func_name())) {
2362     throw unimplemented_lowering(v);
2363   }
2364 
2365   const auto& bufs_out = v->buf_out_args();
2366   const auto& bufs_in = v->buf_args();
2367 
2368   const auto bufs_in_size = bufs_in.size();
2369   const auto bufs_out_size = bufs_out.size();
2370   const auto args_num = v->args().size();
2371 
2372   // Count the size of the dims array - it consists of the dimensions of
2373   // all bufs concatenated together.
2374   size_t dims_num = 0;
2375   for (const auto& b : bufs_in) {
2376     dims_num += b->dims().size();
2377   }
2378 
2379   // The buf_ptrs array holds, in order: bufs_out_size data pointers for
2380   // the out tensors, bufs_in_size data pointers for the inputs, and
2381   // bufs_out_size TensorImpl* handles for the out tensors, later passed
2382   // to nnc_aten_free to release them.
2383 #if LLVM_VERSION_MAJOR >= 15
2384   llvm::Value* buf_ptrs = irb_.CreateAlloca(
2385       OpqPtrTy_,
2386       llvm::ConstantInt::getSigned(IntTy_, bufs_in_size + 2 * bufs_out_size));
2387 #else
2388   llvm::Value* buf_ptrs = irb_.CreateAlloca(
2389       Int8PtrTy_,
2390       llvm::ConstantInt::getSigned(IntTy_, bufs_in_size + 2 * bufs_out_size));
2391 #endif
2392   // @lint-ignore CLANGTIDY
2393   llvm::Value* buf_ranks = irb_.CreateAlloca(
2394       LongTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_in_size));
2395   llvm::Value* buf_dims = irb_.CreateAlloca(
2396       LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
2397   llvm::Value* buf_strides = irb_.CreateAlloca(
2398       LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
2399   llvm::Value* buf_dtypes = irb_.CreateAlloca(
2400       ByteTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_in_size));
2401   // @lint-ignore CLANGTIDY
2402   llvm::Value* extra_args = irb_.CreateAlloca(
2403       LongTy_, llvm::ConstantInt::getSigned(IntTy_, args_num));
2404 
2405   int i = 0;
2406   int dim_idx = 0;
2407   int stride_idx = 0;
2408   for (const auto& b : bufs_in) {
2409     // Store value for buf pointer
2410     b->base_handle()->accept(this);
2411     auto buf_ptr = this->value_;
2412 
2413 #if LLVM_VERSION_MAJOR >= 15
2414     llvm::Value* gep = irb_.CreateInBoundsGEP(
2415         OpqPtrTy_,
2416         buf_ptrs,
2417         // @lint-ignore CLANGTIDY
2418         llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + i));
2419     auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, OpqPtrTy_);
2420 #else
2421     llvm::Value* gep = irb_.CreateInBoundsGEP(
2422         Int8PtrTy_,
2423         buf_ptrs,
2424         // @lint-ignore CLANGTIDY
2425         llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + i));
2426     auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, Int8PtrTy_);
2427 #endif
2428 
2429     irb_.CreateStore(buf_void_ptr, gep);
2430 
2431     // Store dtype of the buf
2432     gep = irb_.CreateInBoundsGEP(
2433         ByteTy_, buf_dtypes, llvm::ConstantInt::getSigned(IntTy_, i));
2434     irb_.CreateStore(
2435         llvm::ConstantInt::getSigned(ByteTy_, (int8_t)b->dtype().scalar_type()),
2436         gep);
2437 
2438     // Store rank of the buf
2439     // @lint-ignore CLANGTIDY
2440     gep = irb_.CreateInBoundsGEP(
2441         LongTy_, buf_ranks, llvm::ConstantInt::getSigned(IntTy_, i));
2442     irb_.CreateStore(
2443         llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);
2444 
2445     // Store dims of the buf
2446     for (const auto dim : c10::irange(b->dims().size())) {
2447       gep = irb_.CreateInBoundsGEP(
2448           LongTy_, buf_dims, llvm::ConstantInt::getSigned(IntTy_, dim_idx));
2449       b->dims()[dim]->accept(this);
2450       auto dim_val = this->value_;
2451       irb_.CreateStore(irb_.CreateZExt(dim_val, LongTy_), gep);
2452       dim_idx++;
2453     }
2454 
2455     // Store strides of the buf
2456     for (const auto dim : c10::irange(b->dims().size())) {
2457       gep = irb_.CreateInBoundsGEP(
2458           LongTy_,
2459           buf_strides,
2460           llvm::ConstantInt::getSigned(IntTy_, stride_idx));
2461       b->strides()[dim]->accept(this);
2462       auto stride_val = this->value_;
2463       irb_.CreateStore(irb_.CreateZExt(stride_val, LongTy_), gep);
2464       stride_idx++;
2465     }
2466 
2467     i++;
2468   }
2469 
2470   i = 0;
2471   for (const ExprPtr& arg : v->args()) {
2472     auto gep = irb_.CreateInBoundsGEP(
2473         LongTy_, extra_args, llvm::ConstantInt::getSigned(IntTy_, i));
2474     arg->accept(this);
2475     irb_.CreateStore(irb_.CreateZExtOrBitCast(this->value_, LongTy_), gep);
2476     i++;
2477   }
2478 
2479   // Generate the call itself
2480   std::string fname = v->func_name();
2481 
2482 #if LLVM_VERSION_MAJOR >= 15
2483   FunctionCallee callee = module_->getOrInsertFunction(
2484       fname,
2485       llvm::FunctionType::get(
2486           llvm::Type::getVoidTy(getContext()), // return type
2487           {LongTy_, // int64_t bufs_in_size
2488            OpqPtrTy_, // void** buf_data
2489            OpqPtrTy_, // int64_t* buf_ranks
2490            OpqPtrTy_, // int64_t* buf_dims
2491            OpqPtrTy_, // int64_t* buf_strides
2492            OpqPtrTy_, // int8_t* buf_dtypes
2493            LongTy_, // int64_t args_num
2494            OpqPtrTy_}, // int64_t* extra_args
2495           false)); // is var_arg
2496 #else
2497   FunctionCallee callee = module_->getOrInsertFunction(
2498       fname,
2499       llvm::FunctionType::get(
2500           llvm::Type::getVoidTy(getContext()), // return type
2501           {LongTy_, // int64_t bufs_in_size
2502            Int8PtrTy_->getPointerTo(), // void** buf_data
2503            LongTy_->getPointerTo(), // int64_t* buf_ranks
2504            LongTy_->getPointerTo(), // int64_t* buf_dims
2505            LongTy_->getPointerTo(), // int64_t* buf_strides
2506            ByteTy_->getPointerTo(), // int8_t* buf_dtypes
2507            LongTy_, // int64_t args_num
2508            LongTy_->getPointerTo()}, // int64_t* extra_args
2509           false)); // is var_arg
2510 #endif
2511 
2512   auto call_ty = callee.getFunctionType();
2513   auto call_fn = callee.getCallee();
2514   llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);
2515 
2516   irb_.CreateCall(
2517       call_ty,
2518       call_fn,
2519       // @lint-ignore CLANGTIDY
2520       {llvm::ConstantInt::getSigned(LongTy_, bufs_in_size),
2521        buf_ptrs,
2522        buf_ranks,
2523        buf_dims,
2524        buf_strides,
2525        buf_dtypes,
2526        // @lint-ignore CLANGTIDY
2527        llvm::ConstantInt::getSigned(LongTy_, args_num),
2528        extra_args});
2529 
2530   // @lint-ignore CLANGTIDY
2531   for (const auto i : c10::irange(bufs_out_size)) {
2532     const auto& buf_out = bufs_out[i];
2533 #if LLVM_VERSION_MAJOR >= 15
2534     auto gep = irb_.CreateInBoundsGEP(
2535         OpqPtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2536     llvm::Value* ptr = irb_.CreatePointerCast(
2537         irb_.CreateLoad(OpqPtrTy_, gep), dtypeToLLVMPtr(buf_out->dtype()));
2538 #else
2539     auto gep = irb_.CreateInBoundsGEP(
2540         Int8PtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2541     llvm::Value* ptr = irb_.CreatePointerCast(
2542         irb_.CreateLoad(Int8PtrTy_, gep), dtypeToLLVMPtr(buf_out->dtype()));
2543 #endif
2544 
2545     varToVal_[buf_out->base_handle()] = ptr;
2546 
2547     for (auto it = bufsExtAllocReuse_.find(buf_out);
2548          it != bufsExtAllocReuse_.end();
2549          it++) {
2550       auto buf = it->second;
2551       handleBufReuse(buf, buf_out);
2552     }
2553     bufsExtAllocReuse_.erase(buf_out);
2554 
2555 #if LLVM_VERSION_MAJOR >= 15
2556     gep = irb_.CreateInBoundsGEP(
2557         OpqPtrTy_,
2558         buf_ptrs,
2559         // @lint-ignore CLANGTIDY
2560         llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + bufs_in_size + i));
2561     bufsExtToFreeVal_[buf_out->base_handle()] = irb_.CreateLoad(OpqPtrTy_, gep);
2562 #else
2563     gep = irb_.CreateInBoundsGEP(
2564         Int8PtrTy_,
2565         buf_ptrs,
2566         // @lint-ignore CLANGTIDY
2567         llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + bufs_in_size + i));
2568     bufsExtToFreeVal_[buf_out->base_handle()] =
2569         irb_.CreateLoad(Int8PtrTy_, gep);
2570 #endif
2571   }
2572 
2573   value_ = llvm::ConstantInt::get(IntTy_, 0);
2574 }
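// Unlike the non-allocating variant above, here the callee writes the
// out-tensor data pointers (and their TensorImpl* handles) back into
// buf_ptrs; the loop above reads them out, publishes them via varToVal_,
// and stashes the handles in bufsExtToFreeVal_ so the FreeExt visitor can
// later release them through nnc_aten_free.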
2575 
2576 void LLVMCodeGenImpl::visit(const AllocatePtr& v) {
2577   llvm::Value* size =
2578       llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
2579   for (ExprPtr e : v->dims()) {
2580     e->accept(this);
2581     size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
2582   }
2583 
2584   value_ = llvm::ConstantInt::get(IntTy_, 0);
2585 
2586   if (llvm::ConstantInt* CI = llvm::dyn_cast<llvm::ConstantInt>(size)) {
2587     if (CI->getSExtValue() < 512) {
2588       llvm::Value* alloca = irb_.CreateAlloca(dtypeToLLVM(v->dtype()), size);
2589       varToVal_[v->buffer_var()] = alloca;
2590       return;
2591     }
2592   }
2593 
2594 #if LLVM_VERSION_MAJOR > 17
2595   llvm::Instruction* I = irb_.CreateMalloc(
2596       LongTy_, dtypeToLLVM(v->dtype()), size, nullptr, nullptr, "");
2597 #else
2598   llvm::Instruction* I = llvm::CallInst::CreateMalloc(
2599       irb_.GetInsertBlock(),
2600       LongTy_,
2601       dtypeToLLVM(v->dtype()),
2602       size,
2603       nullptr,
2604       nullptr);
2605 #endif
2606   // Insert the malloc (and any bitcast it wraps) into the current block.
2607   irb_.SetInsertPoint(irb_.GetInsertBlock());
2608   llvm::Value* malloc = irb_.Insert(I);
2609   varToVal_[v->buffer_var()] = malloc;
2610 }
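// Note: buffers with a constant size under 512 bytes live on the stack via
// alloca; larger or dynamically sized buffers are heap-allocated with
// malloc and released by the Free visitor below.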
2611 
2612 void LLVMCodeGenImpl::visit(const PlacementAllocatePtr& v) {
2613   auto buf_to_reuse = v->buf_to_reuse();
2614   auto buf = v->buf();
2615 
2616   if (bufsExtAlloc_.count(buf_to_reuse)) {
2617     bufsExtAllocReuse_.insert({buf_to_reuse, buf});
2618     return;
2619   }
2620 
2621   handleBufReuse(buf, buf_to_reuse);
2622 }
2623 
2624 void LLVMCodeGenImpl::visit(const FreePtr& v) {
2625   value_ = llvm::ConstantInt::get(IntTy_, 0);
2626 
2627   llvm::Value* ptr = bufsExtToFreeVal_.count(v->buffer_var())
2628       ? bufsExtToFreeVal_.at(v->buffer_var())
2629       : varToVal_.at(v->buffer_var());
2630 
2631   if (!llvm::isa<llvm::AllocaInst>(ptr)) {
2632 #if LLVM_VERSION_MAJOR > 17
2633     irb_.Insert(irb_.CreateFree(ptr));
2634 #else
2635     irb_.Insert(llvm::CallInst::CreateFree(ptr, irb_.GetInsertBlock()));
2636 #endif
2637   }
2638 }
2639 
2640 void LLVMCodeGenImpl::visit(const FreeExtPtr& v) {
2641   value_ = llvm::ConstantInt::get(IntTy_, 0);
2642   const auto& bufs = v->bufs();
2643   const auto bufs_num = bufs.size();
2644 
2645 #if LLVM_VERSION_MAJOR >= 15
2646   llvm::Value* ptrs = irb_.CreateAlloca(
2647       OpqPtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2648 #else
2649   llvm::Value* ptrs = irb_.CreateAlloca(
2650       Int8PtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
2651 #endif
2652 
2653   for (const auto i : c10::irange(bufs_num)) {
2654     const auto& buf = bufs[i];
2655 #if LLVM_VERSION_MAJOR >= 15
2656     llvm::Value* gep = irb_.CreateInBoundsGEP(
2657         OpqPtrTy_, ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2658 #else
2659     llvm::Value* gep = irb_.CreateInBoundsGEP(
2660         Int8PtrTy_, ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
2661 #endif
2662 
2663     auto ptr = bufsExtToFreeVal_[buf->base_handle()];
2664     irb_.CreateStore(ptr, gep);
2665   }
2666 
2667 #if LLVM_VERSION_MAJOR >= 15
2668   FunctionCallee callee = module_->getOrInsertFunction(
2669       "nnc_aten_free",
2670       llvm::FunctionType::get(
2671           llvm::Type::getVoidTy(getContext()), // return type
2672           {
2673               LongTy_, // int64_t bufs_num
2674               OpqPtrTy_, // void** ptrs
2675           },
2676           false)); // is var_arg
2677 #else
2678   FunctionCallee callee = module_->getOrInsertFunction(
2679       "nnc_aten_free",
2680       llvm::FunctionType::get(
2681           llvm::Type::getVoidTy(getContext()), // return type
2682           {
2683               LongTy_, // int64_t bufs_num
2684               Int8PtrTy_->getPointerTo(), // void** ptrs
2685           },
2686           false)); // is var_arg
2687 #endif
2688 
2689   auto call_ty = callee.getFunctionType();
2690   auto call_fn = callee.getCallee();
2691   llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);
2692 
2693   irb_.CreateCall(
2694       call_ty,
2695       call_fn,
2696       {llvm::ConstantInt::getSigned(LongTy_, bufs_num), ptrs});
2697 
2698   value_ = llvm::ConstantInt::get(IntTy_, 0);
2699 }
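// Hedged note: the FunctionType above implies the runtime hook has the
// C ABI
//   void nnc_aten_free(int64_t bufs_num, void** ptrs);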
2700 
2701 void LLVMCodeGenImpl::visit(const LetPtr& v) {
2702   v->value()->accept(this);
2703   if (!varToVal_.count(v->var())) {
2704     varToVal_.emplace(v->var(), value_);
2705     scopeToVar_[scope_].push_back(v->var());
2706   } else {
2707     throw std::runtime_error("var should not exist before");
2708   }
2709 }
2710 
2711 void LLVMCodeGenImpl::visit(const CondPtr& v) {
2712   // Evaluate the condition even when true_stmt and false_stmt are both
2713   // nullptr, since the condition may be a function call with side
2714   // effects that must still run.
2715   v->condition()->accept(this);
2716 
2717   if (!v->true_stmt() && !v->false_stmt()) {
2718     return;
2719   }
2720   assert(v->true_stmt());
2721 
2722   llvm::Value* condition = value_;
2723   llvm::Value* c = irb_.CreateICmpNE(
2724       condition, llvm::ConstantInt::get(condition->getType(), 0));
2725   llvm::BasicBlock* then_block =
2726       llvm::BasicBlock::Create(getContext(), "then", fn_);
2727   llvm::BasicBlock* else_block = nullptr;
2728   if (v->false_stmt()) {
2729     else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
2730   }
2731   llvm::BasicBlock* end_block =
2732       llvm::BasicBlock::Create(getContext(), "end", fn_);
2733 
2734   if (else_block) {
2735     irb_.CreateCondBr(c, then_block, else_block);
2736   } else {
2737     irb_.CreateCondBr(c, then_block, end_block);
2738   }
2739 
2740   irb_.SetInsertPoint(then_block);
2741   v->true_stmt()->accept(this);
2742   irb_.CreateBr(end_block);
2743 
2744   if (else_block) {
2745     irb_.SetInsertPoint(else_block);
2746     v->false_stmt()->accept(this);
2747     irb_.CreateBr(end_block);
2748   }
2749 
2750   irb_.SetInsertPoint(end_block);
2751 }
2752 
2753 // "New" PassManager needed to replace TM.adjustPassManager
2754 #if LLVM_VERSION_MAJOR >= 15
2755 void LLVMCodeGenImpl::optimize(llvm::Module& M) {
2756   // Add internal analysis passes from the target machine.
2757   auto& TM = jit_->getTargetMachine();
2758 
2759   // Create the analysis managers.
2760   llvm::LoopAnalysisManager LAM;
2761   llvm::FunctionAnalysisManager FAM;
2762   llvm::CGSCCAnalysisManager CGAM;
2763   llvm::ModuleAnalysisManager MAM;
2764 
2765   // Create the new pass manager builder.
2766   // Take a look at the PassBuilder constructor parameters for more
2767   // customization, e.g. specifying a TargetMachine or various debugging
2768   // options.
2769   llvm::PassBuilder PB(&TM);
2770 
2771 #if LLVM_VERSION_MAJOR >= 18 && LLVM_VERSION_MAJOR < 19
2772   TM.registerPassBuilderCallbacks(PB, false);
2773 #else
2774   TM.registerPassBuilderCallbacks(PB);
2775 #endif
2776 
2777   // Register all the basic analyses with the managers.
2778   PB.registerModuleAnalyses(MAM);
2779   PB.registerCGSCCAnalyses(CGAM);
2780   PB.registerFunctionAnalyses(FAM);
2781   PB.registerLoopAnalyses(LAM);
2782   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
2783 
2784   llvm::ModulePassManager MPM =
2785       PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
2786   llvm::FunctionPassManager FPM = PB.buildFunctionSimplificationPipeline(
2787       llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None);
2788 
2789   FAM.registerPass([&] { return TM.getTargetIRAnalysis(); });
2790 
2791   FPM.addPass(llvm::LoopVectorizePass());
2792   FPM.addPass(llvm::SLPVectorizerPass());
2793 
2794   FPM.addPass(llvm::DCEPass());
2795   MPM.addPass(llvm::AlwaysInlinerPass());
2796 
2797   MPM.run(M, MAM);
2798   for (auto& FF : M) {
2799     if (!FF.empty()) {
2800       FPM.run(FF, FAM);
2801     }
2802   }
2803 }
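// Note: the new-PM path runs the O3 per-module pipeline plus a per-function
// simplification pipeline extended with LoopVectorize, SLPVectorizer, and
// DCE, with AlwaysInliner at module scope; this is meant to mirror the
// legacy configuration below (OptLevel = 3, LoopVectorize, SLPVectorize).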
2804 #else // "Old" PassManager
2805 void LLVMCodeGenImpl::optimize(llvm::Module& M) {
2806   llvm::legacy::FunctionPassManager FPM(&M);
2807   llvm::legacy::PassManager PM;
2808 
2809   // Add internal analysis passes from the target machine.
2810   auto& TM = jit_->getTargetMachine();
2811   PM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));
2812   FPM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));
2813 
2814   llvm::PassManagerBuilder PMB;
2815   PMB.OptLevel = 3;
2816   PMB.LoopVectorize = true;
2817   PMB.SLPVectorize = true;
2818   TM.adjustPassManager(PMB);
2819 
2820   PMB.populateFunctionPassManager(FPM);
2821   PMB.populateModulePassManager(PM);
2822   FPM.doInitialization();
2823   PM.add(llvm::createDeadCodeEliminationPass());
2824   PM.add(llvm::createAlwaysInlinerLegacyPass());
2825   PM.run(M);
2826   for (auto& FF : M) {
2827     FPM.run(FF);
2828   }
2829   FPM.doFinalization();
2830 }
2831 #endif
2832 
2833 RegisterCodeGen<LLVMCodeGen> llvm_codegen_reg("llvm_codegen");
2834 
2835 #endif // TORCH_ENABLE_LLVM
2836