#ifdef TORCH_ENABLE_LLVM

#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>

#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/llvm_jit.h>

// Note [llvm::SCEVPredicate non-virtual destructor]
// llvm::SCEVPredicate has virtual functions but a non-virtual destructor:
// https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
#include <llvm/Analysis/TargetTransformInfo.h>
#pragma GCC diagnostic pop

#include <llvm/Analysis/CGSCCPassManager.h>
#include <llvm/Analysis/LoopAnalysisManager.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/MC/MCSubtargetInfo.h>
#include <llvm/Pass.h>

// see Note [llvm::SCEVPredicate non-virtual destructor]
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
#include <llvm/Passes/PassBuilder.h>
#pragma GCC diagnostic pop

#if LLVM_VERSION_MAJOR >= 18
#include <llvm/TargetParser/Host.h>
#else
#include <llvm/Support/Host.h>
#endif
#include <llvm/Support/TargetSelect.h>
#include <llvm/Transforms/IPO/AlwaysInliner.h>
#include <llvm/Transforms/Scalar/DCE.h>
#include <llvm/Transforms/Vectorize/LoopVectorize.h>
#include <llvm/Transforms/Vectorize/SLPVectorizer.h>

#if LLVM_VERSION_MAJOR >= 10
#include <llvm/Support/CodeGen.h>
#else
#include <llvm/Target/TargetMachine.h>
#endif

#if LLVM_VERSION_MAJOR >= 11
#include <llvm/Support/TypeSize.h>
#endif

#if LLVM_VERSION_MAJOR < 15
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#endif

#include <llvm/Transforms/Scalar.h>

#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <torch/csrc/jit/jit_log.h>

#include <memory>

using namespace torch::jit::tensorexpr;

C10_DEFINE_bool(
    torch_jit_llvm_use_fast_intrinsics,
    false,
    "Use fast (but slightly less accurate) implementations of tanh and sigmoid");
namespace torch::jit::tensorexpr {

std::optional<std::string>& LLVMTargetTriple() {
  static std::optional<std::string> triple = std::nullopt;
  return triple;
}
std::optional<std::string>& LLVMTargetCPU() {
  static std::optional<std::string> cpu = std::nullopt;
  return cpu;
}
std::optional<std::string>& LLVMTargetAttrs() {
  static std::optional<std::string> attrs = std::nullopt;
  return attrs;
}
bool& LLVMAOTWorkflow() {
  static bool aot_workflow = false;
  return aot_workflow;
}

namespace {

#if LLVM_VERSION_MAJOR >= 15
// Address and type pair to assist in handling of opaque pointers.
struct TypedPointer {
  TypedPointer() = default;
  TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {}
  llvm::Type* type = nullptr;
  llvm::Value* addr = nullptr;
};
#endif

llvm::CmpInst::Predicate llvm_comparison_predicate(
    CompareSelectOperation compare_op,
    const ScalarType& type) {
  switch (compare_op) {
    case CompareSelectOperation::kEQ:
      return llvm::ICmpInst::ICMP_EQ;
    case CompareSelectOperation::kNE:
      return llvm::ICmpInst::ICMP_NE;
    case CompareSelectOperation::kGT:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SGT
                                     : llvm::ICmpInst::ICMP_UGT;
    case CompareSelectOperation::kGE:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SGE
                                     : llvm::ICmpInst::ICMP_UGE;
    case CompareSelectOperation::kLT:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SLT
                                     : llvm::ICmpInst::ICMP_ULT;
    case CompareSelectOperation::kLE:
      return c10::isSignedType(type) ? llvm::ICmpInst::ICMP_SLE
                                     : llvm::ICmpInst::ICMP_ULE;
    default:
      // TODO: change to a proper error report
      throw std::runtime_error("invalid operator type");
  }
}

llvm::CmpInst::Predicate llvm_fp_comparison_predicate(
    CompareSelectOperation compare_op) {
  switch (compare_op) {
    case CompareSelectOperation::kEQ:
      return llvm::FCmpInst::FCMP_OEQ;
    case CompareSelectOperation::kNE:
      return llvm::FCmpInst::FCMP_ONE;
    case CompareSelectOperation::kGT:
      return llvm::FCmpInst::FCMP_OGT;
    case CompareSelectOperation::kGE:
      return llvm::FCmpInst::FCMP_OGE;
    case CompareSelectOperation::kLT:
      return llvm::FCmpInst::FCMP_OLT;
    case CompareSelectOperation::kLE:
      return llvm::FCmpInst::FCMP_OLE;
    default:
      // TODO: change to a proper error report
      throw std::runtime_error("invalid operator type");
  }
}
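
// Note: the FCMP_O* predicates above are "ordered" comparisons: they yield
// false whenever either operand is NaN (e.g. FCMP_OEQ(NaN, NaN) is false),
// matching the IEEE-754 semantics of the corresponding C++ operators.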

#if LLVM_VERSION_MAJOR <= 9
int ElementCount(int lanes) {
  return lanes;
}
#else
llvm::ElementCount ElementCount(int lanes) {
#if LLVM_VERSION_MAJOR <= 11
  return llvm::ElementCount(static_cast<unsigned>(lanes), false);
#elif LLVM_VERSION_MAJOR >= 12
  return llvm::ElementCount::getFixed(lanes);
#else
#error Only LLVM versions 8 and above are supported.
#endif
}
#endif

#if LLVM_VERSION_MAJOR >= 9

using FunctionCallee = llvm::FunctionCallee;

#elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009

struct FunctionCallee {
  FunctionCallee() {}

  FunctionCallee(llvm::Constant* fn)
      : v_(fn), ft_(cast<llvm::Function>(v_)->getFunctionType()) {}

  llvm::FunctionType* getFunctionType() {
    return ft_;
  }

  llvm::Value* getCallee() {
    return v_;
  }

 private:
  llvm::Value* v_{nullptr};
  llvm::FunctionType* ft_{nullptr};
};

#else
#error Only LLVM versions 8 and above are supported.
#endif
} // namespace

class LLVMCodeGenCallee {
 public:
  LLVMCodeGenCallee(
      std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit,
      void* kernelAddress)
      : jit_(std::move(jit)), kernelAddress_(kernelAddress) {}

  llvm::orc::PytorchLLVMJIT* getJIT() {
    return jit_.get();
  }

  void* getKernelAddress() {
    return kernelAddress_;
  }

  void setKernelAddress(void* kernelAddress) {
    kernelAddress_ = kernelAddress;
  }

 private:
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
  void* kernelAddress_;
};

class LLVMCodeGenImpl : public IRVisitor {
 private:
  std::unique_ptr<llvm::LLVMContext> context_;
  llvm::IRBuilder<> irb_;
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
  std::unique_ptr<llvm::Module> module_;
  llvm::Function* fn_;
  llvm::BasicBlock* bb_;
  llvm::Value* value_{nullptr};
  llvm::JITTargetAddress kernelAddress_;
  std::string kernel_func_name_;

#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_;
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE);
#undef LLVM_TYPE_DECLARE

#if LLVM_VERSION_MAJOR >= 15
  llvm::Type* OpqPtrTy_;
#else
  llvm::Type* Int8PtrTy_;
#endif
  llvm::Type* VoidTy_;
  std::unordered_map<VarPtr, int> varToArg_;
  std::unordered_map<VarPtr, llvm::Value*> varToVal_;
  std::unordered_set<BufPtr> bufsExtAlloc_;
  std::unordered_map<VarPtr, llvm::Value*> bufsExtToFreeVal_;
  std::unordered_multimap<BufPtr, BufPtr> bufsExtAllocReuse_;
  std::unordered_map<BlockPtr, std::vector<VarPtr>> scopeToVar_;
  BlockPtr scope_;

  std::string llvmCode_;
  std::string asmCode_;

 private:
  llvm::LLVMContext& getContext();
  llvm::Type* dtypeToLLVM(Dtype dtype);
  llvm::Type* dtypeToLLVMPtr(Dtype dtype);
  void emitWrapper(const std::vector<llvm::Type*>& params);
  void emitKernel(StmtPtr stmt, const std::vector<llvm::Type*>& params);
  llvm::Value* toVec(llvm::Value* v, int lanes);

  enum Arity {
    Unary,
    Binary,
  };

  using SimdCallee = std::tuple<llvm::FunctionType*, llvm::Value*, bool>;
  SimdCallee getSimdFunction(
      const std::string& name,
      llvm::Type* type,
      Arity arity,
      int lanes);

  llvm::Value* varToValue(VarPtr var);
  void replaceVarMapping(
      const std::vector<VarPtr>& vars,
      const std::vector<llvm::Value*>& vals);

#if LLVM_VERSION_MAJOR >= 15
  TypedPointer packFuncArgs(const std::vector<llvm::Value*>& func_args);
  std::vector<llvm::Value*> unpackFuncArgs(TypedPointer packed, int arg_count);
#else
  llvm::Value* packFuncArgs(const std::vector<llvm::Value*>& func_args);
  std::vector<llvm::Value*> unpackFuncArgs(llvm::Value* packed, int arg_count);
#endif

  void processParallelFor(ForPtr v);
  void handleBufReuse(BufPtr buf, BufPtr buf_to_reuse);

 public:
  LLVMCodeGenImpl(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>& args,
      at::Device device,
      Dtype dtype,
      std::string kernel_func_name,
      std::optional<std::string> triple,
      std::optional<std::string> cpu,
      std::optional<std::string> attrs);
  ~LLVMCodeGenImpl() override = default;

  llvm::JITTargetAddress getKernelAddress() const;
  std::unique_ptr<llvm::orc::PytorchLLVMJIT> releaseJIT();

  void visit(const AddPtr& v) override;
  void visit(const SubPtr& v) override;
  void visit(const MulPtr& v) override;
  void visit(const DivPtr& v) override;
  void visit(const ModPtr& v) override;
  void visit(const MaxPtr& v) override;
  void visit(const MinPtr& v) override;
  void visit(const AndPtr& v) override;
  void visit(const OrPtr& v) override;
  void visit(const XorPtr& v) override;
  void visit(const LshiftPtr& v) override;
  void visit(const RshiftPtr& v) override;
  void visit(const CompareSelectPtr& v) override;

#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##ImmPtr& v) override;
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE

  void visit(const CastPtr& v) override;
  void visit(const BitCastPtr& v) override;
  void visit(const VarPtr& v) override;
  void visit(const RampPtr& v) override;
  void visit(const LoadPtr& v) override;
  void visit(const ForPtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const BroadcastPtr& v) override;
  void visit(const IfThenElsePtr& v) override;
  void visit(const IntrinsicsPtr& v) override;
  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const FreeExtPtr& v) override;
  void visit(const PlacementAllocatePtr& v) override;
  void visit(const LetPtr& v) override;
  void visit(const CondPtr& v) override;
  void visit(const ExternalCallPtr& v) override;
  void visit(const ExternalCallWithAllocPtr& v) override;

  void emitIsNan(IntrinsicsPtr v);

  llvm::Value* emitUnmaskedLoad(
      llvm::Type* ty,
      llvm::Value* addr,
      llvm::Value* idx);
  llvm::Value* emitMaskedLoad(
      llvm::Type* ty,
      llvm::Value* addr,
      llvm::Value* idx,
      llvm::Value* mask);
  void emitUnmaskedStore(
      llvm::Type* ty,
      llvm::Value* base,
      llvm::Value* idx,
      llvm::Value* val);
  void emitMaskedStore(
      llvm::Type* ty,
      llvm::Value* base,
      llvm::Value* idx,
      llvm::Value* mask,
      llvm::Value* val);

  void optimize(llvm::Module& M);
  std::string getLLVMCodeText() {
    return llvmCode_;
  }
  std::string getASMCodeText() {
    return asmCode_;
  }
};

} // namespace torch::jit::tensorexpr

LLVMCodeGen::~LLVMCodeGen() = default;

LLVMCodeGen::LLVMCodeGen(StmtPtr stmt)
    : LLVMCodeGen(stmt, std::vector<CodeGen::BufferArg>()) {}

LLVMCodeGen::LLVMCodeGen(
    StmtPtr stmt,
    const std::vector<BufferArg>& args,
    at::Device device,
    const std::string& kernel_func_name,
    Dtype dtype,
    std::optional<std::string> triple,
    std::optional<std::string> cpu,
    std::optional<std::string> attrs)
    : CodeGen(stmt, args, device, kernel_func_name) {
  impl_ = std::make_unique<LLVMCodeGenImpl>(
      this->stmt(),
      args,
      device,
      dtype,
      this->kernel_func_name(),
      triple,
      cpu,
      attrs);
  callee_ = std::make_unique<LLVMCodeGenCallee>(
      impl_->releaseJIT(), (void*)impl_->getKernelAddress());
}

void LLVMCodeGen::cleanup_memory() {
  impl_.reset(nullptr);
}

void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
  value<float>(const_cast<void**>(args.data()));
}

void LLVMCodeGen::call_with_numel(void** args, int64_t /* numel */) {
  value<float>(const_cast<void**>(args));
}

void LLVMCodeGen::call(const std::vector<CallArg>& args) {
  auto& buf_args = buffer_args();
  if (args.size() != buf_args.size()) {
    throw malformed_input("wrong number of args in call");
  }

  constexpr unsigned nargs = 8;
  c10::SmallVector<void*, nargs> argv;
  argv.resize(buf_args.size());
  for (size_t i = 0, e = buf_args.size(); i < e; i++) {
    auto const& bufferArg = buf_args[i];
    auto const& callArg = args[i];
    argv[i] = argToPtr(bufferArg, callArg);
  }
  value<float>(argv.data());
}

at::Tensor LLVMCodeGen::empty_strided(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  return at::native::empty_strided_cpu(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

void* LLVMCodeGen::getKernelAddress(LLVMCodeGenCallee* callee) {
  return (void*)callee->getKernelAddress();
}

std::string LLVMCodeGen::getCodeText(const std::string& attr /*=""*/) {
  TORCH_INTERNAL_ASSERT(
      impl_.get(),
      "LLVMCodeGen memory has been cleaned up. So, code text is not available at this point");
  if (attr == "asm") {
    return impl_->getASMCodeText();
  } else {
    return impl_->getLLVMCodeText();
  }
}

llvm::JITTargetAddress LLVMCodeGenImpl::getKernelAddress() const {
  return kernelAddress_;
}

std::unique_ptr<llvm::orc::PytorchLLVMJIT> LLVMCodeGenImpl::releaseJIT() {
  return std::move(jit_);
}

namespace {
// Global mutex to protect LLVM initialization. TargetRegistry::lookupTarget
// in particular is not thread-safe.
static std::mutex llvmInitMutex;
} // namespace

LLVMCodeGenImpl::LLVMCodeGenImpl(
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& args,
    at::Device device,
    Dtype dtype,
    std::string kernel_func_name,
    std::optional<std::string> triple,
    std::optional<std::string> cpu,
    std::optional<std::string> attrs)
    : context_(std::make_unique<llvm::LLVMContext>()),
      irb_(getContext()),
      kernel_func_name_(std::move(kernel_func_name)),
      bufsExtAlloc_(ExternalAllocBufFinder::find(stmt)) {
  if (!triple) {
    triple = LLVMTargetTriple();
  }
  if (!cpu) {
    cpu = LLVMTargetCPU();
  }
  if (!attrs) {
    attrs = LLVMTargetAttrs();
  }
  // Manually map types to LLVM types.
  ByteTy_ = llvm::Type::getInt8Ty(getContext());
  CharTy_ = llvm::Type::getInt8Ty(getContext());
  ShortTy_ = llvm::Type::getInt16Ty(getContext());
  IntTy_ = llvm::Type::getInt32Ty(getContext());
  LongTy_ = llvm::Type::getInt64Ty(getContext());
  HalfTy_ = llvm::Type::getHalfTy(getContext());
  FloatTy_ = llvm::Type::getFloatTy(getContext());
  DoubleTy_ = llvm::Type::getDoubleTy(getContext());
  VoidTy_ = llvm::Type::getVoidTy(getContext());
  BoolTy_ = ByteTy_;
#if LLVM_VERSION_MAJOR >= 15
  OpqPtrTy_ = llvm::PointerType::getUnqual(getContext());
#else
  Int8PtrTy_ = llvm::Type::getInt8PtrTy(getContext());
#endif

  {
    std::lock_guard<std::mutex> g(llvmInitMutex);
    llvm::InitializeAllTargets();
    llvm::InitializeAllTargetMCs();
    llvm::InitializeAllAsmPrinters();
    jit_ = std::make_unique<llvm::orc::PytorchLLVMJIT>(triple, cpu, attrs);
  }

  module_ = std::make_unique<llvm::Module>("pytorch", getContext());
  module_->setDataLayout(jit_->getDataLayout());
  module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str());

  // We support float16 ops by casting expr inputs to float32
  // and then casting the result back to float16.
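  // E.g. a Half-typed "x + y" is rewritten into (roughly)
  // half(float(x) + float(y)), so the arithmetic itself is done in float32.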

  GRAPH_DEBUG("Before HalfRewriter ", *stmt);
  HalfRewriter hsFix;
  stmt = stmt->accept_mutator(&hsFix);
  GRAPH_DEBUG("After HalfRewriter ", *stmt);

  // Emit prototype and bind argument Vars to parameter indices.
  llvm::Type* retTy = dtypeToLLVM(dtype);
  std::vector<llvm::Type*> params;
  for (const auto i : c10::irange(args.size())) {
    auto const& arg = args[i];
    if (arg.isVar()) {
      params.push_back(dtypeToLLVM(arg.dtype()));
    } else {
      params.push_back(dtypeToLLVMPtr(arg.dtype()));
    }
    varToArg_[arg.var()] = i;
  }
  llvm::FunctionType* fntype = llvm::FunctionType::get(retTy, params, false);
  fn_ = llvm::Function::Create(
      fntype, llvm::Function::PrivateLinkage, "pytorch", module_.get());
  fn_->addFnAttr(llvm::Attribute::AlwaysInline);
  for (const auto i : c10::irange(args.size())) {
    if (!args[i].isVar()) {
      fn_->addParamAttr(i, llvm::Attribute::NoAlias);
    }
  }

  emitWrapper(params);
  emitKernel(stmt, params);

  jit_->addModule(std::move(module_), std::move(context_));
  if (!LLVMAOTWorkflow()) {
    auto sym = jit_->findSymbol(kernel_func_name_);
    kernelAddress_ = assertSuccess(sym.getAddress());
  }
}

llvm::LLVMContext& LLVMCodeGenImpl::getContext() {
  return *context_;
}

llvm::Type* LLVMCodeGenImpl::dtypeToLLVM(Dtype dtype) {
  switch (dtype.scalar_type()) {
#define TYPE_CASE(_1, n) \
  case ScalarType::n:    \
    return n##Ty_;       \
    break;

    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      return CharTy_;
      break;

    case ScalarType::QUInt8:
      return ByteTy_;
      break;

    case ScalarType::BFloat16:
      return ShortTy_;
      break;

    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

llvm::Type* LLVMCodeGenImpl::dtypeToLLVMPtr(Dtype dtype) {
  return dtypeToLLVM(dtype)->getPointerTo();
}

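// Emit the JIT-visible wrapper around the private kernel. Its signature is
// effectively "int wrapper(void** args)": for each parameter it loads the
// corresponding entry from the args array (pointers are passed through,
// scalars are loaded by value) and then calls the always-inlined kernel.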
void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
#if LLVM_VERSION_MAJOR >= 15
  auto wrapper = llvm::Function::Create(
      llvm::FunctionType::get(IntTy_, {OpqPtrTy_}, false),
      llvm::Function::ExternalLinkage,
      kernel_func_name_,
      module_.get());
#else
  auto voidPtrTy = llvm::Type::getInt8PtrTy(getContext());
  auto voidPtrPtrTy = voidPtrTy->getPointerTo();
  auto wrapper = llvm::Function::Create(
      llvm::FunctionType::get(IntTy_, {voidPtrPtrTy}, false),
      llvm::Function::ExternalLinkage,
      kernel_func_name_,
      module_.get());
#endif

  {
    // Work around UBSAN crashes caused by reading 8 bytes in front of every
    // function. If this function were placed at the very beginning of a page,
    // reading 8 bytes before the page could trigger a wild-addr-read ASAN
    // failure if the preceding page is not mapped.
    // - https://reviews.llvm.org/D148665
    // - https://github.com/llvm/llvm-project/issues/65253
    // Prefix data attaches the padding directly to the function; the
    // optimizer might otherwise drop a separate padding variable and disable
    // this workaround.
    // https://llvm.org/docs/LangRef.html#prefix-data
    wrapper->setPrefixData(llvm::Constant::getNullValue(
        llvm::ArrayType::get(llvm::Type::getInt8Ty(getContext()), 8)));
  }

  auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
  irb_.SetInsertPoint(wrapBB);
  llvm::SmallVector<llvm::Value*, 6> wrappedArgs;
  for (const auto i : c10::irange(params.size())) {
#if LLVM_VERSION_MAJOR >= 15
    auto argp = irb_.CreateGEP(
        OpqPtrTy_,
        wrapper->arg_begin(),
        llvm::ConstantInt::getSigned(IntTy_, i));
    if (params[i]->isPointerTy()) {
      auto arg =
          irb_.CreatePointerCast(irb_.CreateLoad(OpqPtrTy_, argp), params[i]);
      wrappedArgs.push_back(arg);
    } else {
      auto p =
          irb_.CreatePointerCast(irb_.CreateLoad(OpqPtrTy_, argp), OpqPtrTy_);
      auto arg = irb_.CreateLoad(params[i], p);
      wrappedArgs.push_back(arg);
    }
#else
    auto argp = irb_.CreateGEP(
        voidPtrTy,
        wrapper->arg_begin(),
        llvm::ConstantInt::getSigned(IntTy_, i));
    if (params[i]->isPointerTy()) {
      auto arg = irb_.CreatePointerCast(
          irb_.CreateLoad(argp->getType()->getPointerElementType(), argp),
          params[i]);
      wrappedArgs.push_back(arg);
    } else {
      auto p = irb_.CreatePointerCast(
          irb_.CreateLoad(argp->getType()->getPointerElementType(), argp),
          params[i]->getPointerTo());
      auto arg = irb_.CreateLoad(p->getType()->getPointerElementType(), p);
      wrappedArgs.push_back(arg);
    }
#endif
  }
  auto cc = irb_.CreateCall(fn_, wrappedArgs);
  irb_.CreateRet(cc);
}

class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
 private:
  ExprPtr mutate(const IntrinsicsPtr& v) override {
    if (v->op_type() == kTanh) {
      ScalarType stype = v->dtype().scalar_type();
      if (stype == ScalarType::Float) {
        return fast_tanh(ExprHandle(v->param(0)->accept_mutator(this))).node();
      }
    } else if (v->op_type() == kSigmoid) {
      ScalarType stype = v->dtype().scalar_type();
      if (stype == ScalarType::Float) {
        return fast_sigmoid(ExprHandle(v->param(0)->accept_mutator(this)))
            .node();
      }
    }
    // TODO: fast exp
    // TODO: fast erf
    return GenericIntrinsicsExpander::mutate(v);
  }
};

void LLVMCodeGenImpl::emitKernel(
    StmtPtr stmt,
    const std::vector<llvm::Type*>& params) {
  // Set insert point to the real function.
  bb_ = llvm::BasicBlock::Create(getContext(), "entry", fn_);
  irb_.SetInsertPoint(bb_);

  // Maybe expand some of the intrinsics.
  if (FLAGS_torch_jit_llvm_use_fast_intrinsics) {
    LLVMIntrinsicsExpander intrinsics_expander;
    stmt = stmt->accept_mutator(&intrinsics_expander);
  } else {
    GenericIntrinsicsExpander intrinsics_expander;
    stmt = stmt->accept_mutator(&intrinsics_expander);
  }

  // Compile the kernel.
  stmt->accept(this);

  // If the kernel is empty, set a default return value.
  if (value_ == nullptr) {
    value_ = llvm::ConstantInt::get(IntTy_, 0);
  }

  irb_.CreateRet(value_);

  // Print graph debug info before optimization.
  llvm::SmallVector<char, 0> asmBuffer;
  llvm::raw_svector_ostream asmStream(asmBuffer);
  if (GRAPH_DEBUG_ENABLED) {
    module_->print(asmStream, nullptr);
  }
  GRAPH_DEBUG(
      "\nLLVM module before optimizations\n\n", asmStream.str().str(), "\n");

  if (llvm::verifyFunction(*fn_, &llvm::outs())) {
    throw std::runtime_error("Function verification failed");
  }

  optimize(*module_);

  // Print graph debug info after optimization.
  asmBuffer.clear();
  module_->print(asmStream, nullptr);
  llvmCode_ = asmStream.str().str();
  GRAPH_DEBUG(
      "\nLLVM module after optimizations\n\n", asmStream.str().str(), "\n");

  // Emit target assembly for the optimized module.
  asmBuffer.clear();
  llvm::legacy::PassManager PM;
  jit_->getTargetMachine().addPassesToEmitFile(
      PM,
      asmStream,
      nullptr,
#if LLVM_VERSION_MAJOR >= 18
      llvm::CodeGenFileType::AssemblyFile);
#elif LLVM_VERSION_MAJOR >= 10
      llvm::CodeGenFileType::CGFT_AssemblyFile);
#else
      llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#endif
  PM.run(*module_);
  asmCode_ = asmStream.str().str();

  GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n");
}

// TODO: The binary ops are copypasta.

void LLVMCodeGenImpl::visit(const AddPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFAdd(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateAdd(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Add", v);
  }
}

void LLVMCodeGenImpl::visit(const SubPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFSub(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateSub(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Sub", v);
  }
}

void LLVMCodeGenImpl::visit(const MulPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFMul(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateMul(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Mul", v);
  }
}

void LLVMCodeGenImpl::visit(const DivPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  // TODO: Handle arg promotion.
  if (lfp && rfp) {
    value_ = irb_.CreateFDiv(lhs, rhs);
  } else if (!lfp && !rfp) {
    value_ = irb_.CreateSDiv(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Div", v);
  }
}

void LLVMCodeGenImpl::visit(const AndPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateAnd(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in And", v);
  }
}

void LLVMCodeGenImpl::visit(const OrPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateOr(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Or", v);
  }
}

void LLVMCodeGenImpl::visit(const XorPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateXor(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Xor", v);
  }
}

void LLVMCodeGenImpl::visit(const LshiftPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateShl(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Lshift", v);
  }
}

void LLVMCodeGenImpl::visit(const RshiftPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    if (v->lhs()->dtype().is_signed()) {
      value_ = irb_.CreateAShr(lhs, rhs);
    } else {
      value_ = irb_.CreateLShr(lhs, rhs);
    }
  } else {
    throw malformed_input("llvm_codegen: bad type in Rshift", v);
  }
}

void LLVMCodeGenImpl::visit(const ModPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  bool lfp = lhs->getType()->isFPOrFPVectorTy();
  v->rhs()->accept(this);
  auto rhs = this->value_;
  bool rfp = rhs->getType()->isFPOrFPVectorTy();

  if (!lfp && !rfp) {
    value_ = irb_.CreateSRem(lhs, rhs);
  } else {
    throw malformed_input("llvm_codegen: bad type in Mod", v);
  }
}

void LLVMCodeGenImpl::visit(const MaxPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  v->rhs()->accept(this);
  auto rhs = this->value_;

  if (v->dtype().is_integral()) {
    auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSGT(lhs, rhs)
                                       : irb_.CreateICmpUGT(lhs, rhs);
    value_ = irb_.CreateSelect(icmp, lhs, rhs);
    return;
  }

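  // Floating-point max with NaN propagation: if lhs is NaN (i.e. unordered
  // with 0.0), select lhs so the NaN propagates to the result; otherwise
  // select the ordered greater of lhs and rhs.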
  value_ = irb_.CreateSelect(
      irb_.CreateFCmp(
          llvm::FCmpInst::FCMP_UNO,
          lhs,
          llvm::ConstantFP::get(lhs->getType(), 0.0)),
      lhs,
      irb_.CreateSelect(
          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const MinPtr& v) {
  v->lhs()->accept(this);
  auto lhs = this->value_;
  v->rhs()->accept(this);
  auto rhs = this->value_;
  if (v->dtype().is_integral()) {
    auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSLT(lhs, rhs)
                                       : irb_.CreateICmpULT(lhs, rhs);
    value_ = irb_.CreateSelect(icmp, lhs, rhs);
    return;
  }

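  // Same NaN-propagating select structure as Max above, with FCMP_OLT for
  // the ordered comparison.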
  value_ = irb_.CreateSelect(
      irb_.CreateFCmp(
          llvm::FCmpInst::FCMP_UNO,
          lhs,
          llvm::ConstantFP::get(lhs->getType(), 0.0)),
      lhs,
      irb_.CreateSelect(
          irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const CompareSelectPtr& v) {
  auto genUnbiased = [this, v]() -> llvm::Value* {
    v->lhs()->accept(this);
    auto lhs = this->value_;
    v->rhs()->accept(this);
    auto rhs = this->value_;
    v->ret_val1()->accept(this);
    auto retval1 = this->value_;
    v->ret_val2()->accept(this);
    auto retval2 = this->value_;

    auto type_used = v->lhs()->dtype().scalar_type();

    llvm::Value* cmp_;
    CompareSelectOperation cmp_op_ = v->compare_select_op();

    if (c10::isIntegralType(type_used, true)) {
      cmp_ = irb_.CreateICmp(
          llvm_comparison_predicate(cmp_op_, type_used), lhs, rhs);
    } else if (c10::isFloatingType(type_used)) {
      cmp_ = irb_.CreateFCmp(llvm_fp_comparison_predicate(cmp_op_), lhs, rhs);
    } else {
      throw std::runtime_error("invalid type for CompareSelect");
    }

    return irb_.CreateSelect(cmp_, retval1, retval2);
  };

  auto genBiased = [this, v]() -> llvm::Value* {
    v->lhs()->accept(this);
    auto lhs = this->value_;
    v->rhs()->accept(this);
    auto rhs = this->value_;

    auto cmp_type = v->lhs()->dtype().scalar_type();
    auto cmp_op = v->compare_select_op();
    llvm::Value* cmp;

    if (c10::isIntegralType(cmp_type, true)) {
      cmp = irb_.CreateICmp(
          llvm_comparison_predicate(cmp_op, cmp_type), lhs, rhs);
    } else if (c10::isFloatingType(cmp_type)) {
      cmp = irb_.CreateFCmp(llvm_fp_comparison_predicate(cmp_op), lhs, rhs);
    } else {
      throw std::runtime_error("invalid type for CompareSelect");
    }

    auto lanes = v->lhs()->dtype().lanes();
    if (lanes > 1) {
      auto maskType = llvm::Type::getIntNTy(getContext(), lanes);
      auto zero = llvm::ConstantInt::get(maskType, 0);
      auto mask = irb_.CreateBitOrPointerCast(cmp, maskType);
      cmp = irb_.CreateICmpNE(mask, zero);
    }

    auto then_block = llvm::BasicBlock::Create(getContext(), "then", fn_);
    auto else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
    auto end_block = llvm::BasicBlock::Create(getContext(), "block", fn_);
    constexpr int32_t total_weight = 100000;
    auto true_weight = v->bias() == kLikely ? total_weight : 0;
    auto false_weight = total_weight - true_weight;
    irb_.CreateCondBr(
        cmp,
        then_block,
        else_block,
        llvm::MDBuilder(getContext())
            .createBranchWeights(true_weight, false_weight));

    irb_.SetInsertPoint(then_block);
    v->ret_val1()->accept(this);
    llvm::Value* then_val = value_;
    then_block = irb_.GetInsertBlock();
    irb_.CreateBr(end_block);

    irb_.SetInsertPoint(else_block);
    v->ret_val2()->accept(this);
    llvm::Value* else_val = value_;
    else_block = irb_.GetInsertBlock();
    irb_.CreateBr(end_block);

    irb_.SetInsertPoint(end_block);
    llvm::PHINode* phi = irb_.CreatePHI(then_val->getType(), 2);
    phi->addIncoming(then_val, then_block);
    phi->addIncoming(else_val, else_block);
    return phi;
  };

  value_ = v->bias() == kUnbiased ? genUnbiased() : genBiased();
}

template <typename T>
typename std::enable_if<std::is_integral<T>::value, llvm::Value*>::type
getFromType(llvm::Type* type, T value) {
  return llvm::ConstantInt::get(type, value, std::is_signed<T>::value);
}

template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, llvm::Value*>::type
getFromType(llvm::Type* type, T value) {
  return llvm::ConstantFP::get(type, value);
}

#define IMM_VISIT_DECLARE(Type, Name)                  \
  void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \
    value_ = getFromType<Type>(Name##Ty_, v->value()); \
  }
AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE

void LLVMCodeGenImpl::visit(const HalfImmPtr& v) {
  value_ = llvm::ConstantFP::get(HalfTy_, v->value());
}

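// BFloat16 values are carried around as their raw 16-bit pattern in an i16
// (ShortTy_); the Cast lowering below widens them to float32 before any
// arithmetic is done.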
void LLVMCodeGenImpl::visit(const BFloat16ImmPtr& v) {
  value_ = llvm::ConstantInt::get(ShortTy_, v->value().x);
}

void LLVMCodeGenImpl::visit(const BoolImmPtr& v) {
  value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}

static llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
  if (lanes > 1) {
    return llvm::VectorType::get(type, ElementCount(lanes));
  } else {
    return type;
  }
}

void LLVMCodeGenImpl::visit(const CastPtr& v) {
  v->src_value()->accept(this);

  auto dst_type = v->dtype().scalar_type();
  auto src_type = v->src_value()->dtype().scalar_type();
  bool is_to_bf16 = (dst_type == c10::kBFloat16);
  bool is_to_float = (dst_type == c10::kFloat);
  bool is_from_bf16 = (src_type == c10::kBFloat16);
  bool is_from_float = (src_type == c10::kFloat);

  bool cast_from_bf16_to_fp32 = is_from_bf16 && is_to_float;
  bool cast_from_fp32_to_bf16 = is_from_float && is_to_bf16;
  bool non_bf16_cast = (!is_to_bf16) && (!is_from_bf16);
  bool valid_bf16_cast = cast_from_bf16_to_fp32 || cast_from_fp32_to_bf16;
  TORCH_CHECK(
      valid_bf16_cast || non_bf16_cast,
      "Cast is not implemented for the conversion between ",
      src_type,
      " and ",
      dst_type,
      ".");

  llvm::Type* dstType =
      llvmTypeToVec(dtypeToLLVM(v->dtype()), v->dtype().lanes());
  llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype());

  if (srcType == dstType) {
    // do nothing.
    return;
  }

  bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte ||
      v->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->dtype().scalar_type() == ScalarType::Bool;
  bool srcUnsigned =
      v->src_value()->dtype().scalar_type() == ScalarType::Byte ||
      v->src_value()->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->src_value()->dtype().scalar_type() == ScalarType::Bool;

  // Scalar casts
  if (is_from_bf16) {
    // Widen the BF16 value to FP32: zero-extend to 32 bits, shift left by
    // 16 bits, and bit-cast the shifted value to FP32.
    // FP32_VAL = BF16_VAL << 16
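    // E.g. the BF16 pattern 0x3F80 (1.0) becomes 0x3F800000 after the shift,
    // which is exactly the FP32 bit pattern of 1.0f.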
    auto lanes = v->dtype().lanes();
    value_ = irb_.CreateZExt(value_, llvmTypeToVec(IntTy_, lanes));
    auto vec_shl_val = toVec(llvm::ConstantInt::get(IntTy_, 16), lanes);
    value_ = irb_.CreateShl(value_, vec_shl_val);
    value_ =
        irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(FloatTy_, lanes));
    return;
  }

  if (is_to_bf16) {
    // Convert the FP32 value to BF16 with RNE (round to nearest, ties to
    // even) rounding. The algorithm is as follows:
    // STEP1: U32_VAL = BITCAST(F32_VAL)
    // STEP2: U32_VAL_TMP = U32_VAL >> 16
    // STEP3: U32_VAL_TMP = U32_VAL_TMP & 1
    // STEP4: ROUNDING_BIAS = U32_VAL_TMP + UINT32(0x7FFF)
    // STEP5: U32_VAL_TMP = U32_VAL + ROUNDING_BIAS
    // STEP6: BF16_VAL = static_cast<UINT16>(U32_VAL_TMP >> 16)
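    // E.g. 0x3F808000 (exactly halfway between BF16 0x3F80 and 0x3F81) has
    // an even low result bit, so the bias is 0x7FFF and the sum 0x3F80FFFF
    // truncates to the even 0x3F80; for 0x3F818000 the bias is 0x8000 and
    // the sum 0x3F820000 rounds up to the even 0x3F82.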
    auto lanes = v->src_value()->dtype().lanes();
    auto shift_len = llvm::ConstantInt::get(IntTy_, 16);
    auto one = llvm::ConstantInt::get(ShortTy_, 1);
    auto rounding_bias = llvm::ConstantInt::get(ShortTy_, 0x7FFF);
    auto bf16_nan = llvm::ConstantInt::get(ShortTy_, 0xFFFF);

    auto mask = irb_.CreateFCmpOEQ(value_, value_);
    // STEP1: U32_VAL = BITCAST(F32_VAL)
    auto fp32_i32_value =
        irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(IntTy_, lanes));
    // STEP2: U32_VAL_TMP = (U32_VAL >> 16)
    value_ = irb_.CreateLShr(fp32_i32_value, toVec(shift_len, lanes));
    value_ = irb_.CreateTrunc(value_, llvmTypeToVec(ShortTy_, lanes));
    // STEP3: U32_VAL_TMP = U32_VAL_TMP & 1
    value_ = irb_.CreateAnd(value_, toVec(one, lanes));
    // STEP4: ROUNDING_BIAS = U32_VAL_TMP + UINT32(0x7FFF)
    value_ = irb_.CreateAdd(value_, toVec(rounding_bias, lanes));
    value_ = irb_.CreateZExt(value_, llvmTypeToVec(IntTy_, lanes));
    // STEP5: U32_VAL_TMP = U32_VAL + ROUNDING_BIAS
    value_ = irb_.CreateAdd(value_, fp32_i32_value);
    // STEP6: BF16_VAL = static_cast<UINT16>(U32_VAL_TMP >> 16)
    value_ = irb_.CreateLShr(value_, toVec(shift_len, lanes));
    value_ = irb_.CreateTrunc(value_, llvmTypeToVec(ShortTy_, lanes));
    value_ =
        irb_.CreateBitOrPointerCast(value_, llvmTypeToVec(ShortTy_, lanes));
    // If the value is NaN, return BF16 NaN.
    value_ = irb_.CreateSelect(mask, value_, toVec(bf16_nan, lanes));
    return;
  }

  if (srcType->isFPOrFPVectorTy()) {
    if (dstType->isFPOrFPVectorTy()) {
      // As in eager mode, convert Double -> Half by first converting to
      // Float and then to Half. TODO: __truncdfhf2
      if (v->dtype().scalar_type() == ScalarType::Half &&
          v->src_value()->dtype().scalar_type() == ScalarType::Double) {
        value_ = irb_.CreateFPCast(
            value_, llvmTypeToVec(FloatTy_, v->dtype().lanes()));
      }
      value_ = irb_.CreateFPCast(value_, dstType);
    } else if (dstType->isIntOrIntVectorTy()) {
      // Casting Float -> i8 directly doesn't give correct results for bool;
      // instead, set the value to one iff the input float is nonzero.
      if (v->dtype().scalar_type() == ScalarType::Bool) {
        llvm::Value* zero =
            toVec(llvm::ConstantFP::get(srcType, 0.), v->dtype().lanes());
        value_ = irb_.CreateFCmp(llvm::FCmpInst::FCMP_UNE, value_, zero);
        value_ = irb_.CreateICmpEQ(
            value_, llvm::ConstantInt::get(value_->getType(), 1));
        value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned);
        return;
      }

      if (destUnsigned) {
        value_ = irb_.CreateFPToUI(value_, dstType);
      } else {
        value_ = irb_.CreateFPToSI(value_, dstType);
      }
    } else {
      throw unimplemented_lowering(v);
    }
    return;
  }

  if (!srcType->isIntOrIntVectorTy()) {
    throw unimplemented_lowering(v);
  }
  if (dstType->isFPOrFPVectorTy()) {
    if (srcUnsigned) {
      value_ = irb_.CreateUIToFP(value_, dstType);
    } else {
      value_ = irb_.CreateSIToFP(value_, dstType);
    }
  } else if (dstType->isIntOrIntVectorTy()) {
    // Ensure the bool true value is exactly one, since we convert bool to
    // int by zero-extending the int8.
    if (v->dtype().scalar_type() == ScalarType::Bool) {
      llvm::Value* zero =
          toVec(llvm::ConstantInt::get(srcType, 0), v->dtype().lanes());
      value_ = irb_.CreateICmpNE(value_, zero);
    }
    value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned);
  } else {
    throw unimplemented_lowering(v);
  }
}

void LLVMCodeGenImpl::visit(const BitCastPtr& v) {
  v->src_value()->accept(this);

  llvm::Type* dstType = dtypeToLLVM(v->dtype());
  if (v->dtype().lanes() > 1) {
    dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes()));
  }
  llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype());

  if (srcType == dstType) {
    // do nothing.
    return;
  }

  TORCH_CHECK(llvm::CastInst::isBitCastable(
      srcType->getScalarType(), dstType->getScalarType()));
  value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}

void LLVMCodeGenImpl::visit(const VarPtr& v) {
  value_ = varToValue(v);
}

llvm::Value* LLVMCodeGenImpl::varToValue(VarPtr v) {
  // It is possible for v to be in both varToVal_ and varToArg_.
  // In that case, varToVal_ takes precedence.
  if (varToVal_.count(v)) {
    return varToVal_.at(v);
  } else if (varToArg_.count(v)) {
    auto idx = varToArg_.at(v);
    auto arg = fn_->arg_begin() + idx;
    return arg;
  }
  return nullptr;
}

void LLVMCodeGenImpl::replaceVarMapping(
    const std::vector<VarPtr>& vars,
    const std::vector<llvm::Value*>& vals) {
  TORCH_CHECK(vars.size() == vals.size());
  for (const auto i : c10::irange(vars.size())) {
    VarPtr var = vars[i];
    llvm::Value* val = vals[i];
    if (val) {
      varToVal_[var] = val;
    } else {
      varToVal_.erase(var);
    }
  }
}

void LLVMCodeGenImpl::visit(const RampPtr& v) {
  v->base()->accept(this);
  auto base = this->value_;
  v->stride()->accept(this);
  auto stride = this->value_;
  int lanes = v->lanes();

  if (llvm::ConstantInt* const_stride =
          llvm::dyn_cast<llvm::ConstantInt>(stride)) {
    std::vector<llvm::Constant*> vals = {
        llvm::ConstantInt::get(base->getType(), 0)};
    for (int i = 1; i < lanes; ++i) {
      vals.push_back(llvm::ConstantExpr::getAdd(vals.back(), const_stride));
    }

    llvm::Value* offsets = llvm::ConstantVector::get(vals);
    llvm::Value* splat = irb_.CreateVectorSplat(lanes, base);
    value_ = irb_.CreateAdd(splat, offsets);
    return;
  }

  llvm::Type* vecType = nullptr;
  auto element_count = ElementCount(lanes);
  switch (v->dtype().scalar_type()) {
#define TYPE_CASE(_1, Name)                                    \
  case ScalarType::Name:                                       \
    vecType = llvm::VectorType::get(Name##Ty_, element_count); \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      vecType = llvm::VectorType::get(CharTy_, element_count);
      break;
    case ScalarType::QUInt8:
      vecType = llvm::VectorType::get(ByteTy_, element_count);
      break;
    case ScalarType::BFloat16:
      vecType = llvm::VectorType::get(ShortTy_, element_count);
      break;
    default:
      throw std::runtime_error("invalid dtype in Ramp");
  }

  value_ = llvm::UndefValue::get(vecType);
  for (int i = 0; i < lanes; ++i) {
    value_ = irb_.CreateInsertElement(value_, base, i);
    base = irb_.CreateAdd(base, stride);
  }
}

llvm::Value* LLVMCodeGenImpl::emitUnmaskedLoad(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx) {
#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
  return irb_.CreateLoad(ty, addr);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
  return irb_.CreateLoad(addr->getType()->getPointerElementType(), addr);
#endif
}

llvm::Value* LLVMCodeGenImpl::emitMaskedLoad(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx,
    llvm::Value* mask) {
  // Create the block structure for the masked load.
  auto preheader = irb_.GetInsertBlock();
  auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
  auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_);

  // Test the mask.
  auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1));
  irb_.CreateCondBr(cond, condblock, tailblock);

  // Do the load.
  irb_.SetInsertPoint(condblock);

#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
  auto load = irb_.CreateLoad(ty, addr);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
  auto load = irb_.CreateLoad(addr->getType()->getPointerElementType(), addr);
#endif

  irb_.CreateBr(tailblock);

  // Merge the masked and unmasked CFG edges.
  irb_.SetInsertPoint(tailblock);
  auto phi = irb_.CreatePHI(load->getType(), 2);
  phi->addIncoming(llvm::UndefValue::get(load->getType()), preheader);
  phi->addIncoming(load, condblock);

  return phi;
}

void LLVMCodeGenImpl::visit(const LoadPtr& v) {
  if (v->dtype().lanes() == 1) {
    v->base_handle()->accept(this);
    auto base = this->value_;
    v->flat_index()->accept(this);
    auto idx = this->value_;
    value_ = emitUnmaskedLoad(dtypeToLLVM(v->dtype()), base, idx);
    return;
  }

  llvm::Type* loadType = nullptr;

  auto element_count = ElementCount(v->dtype().lanes());
  switch (v->dtype().scalar_type()) {
#define TYPE_CASE(_1, Name)                                     \
  case ScalarType::Name:                                        \
    loadType = llvm::VectorType::get(Name##Ty_, element_count); \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      loadType = llvm::VectorType::get(CharTy_, element_count);
      break;
    case ScalarType::QUInt8:
      loadType = llvm::VectorType::get(ByteTy_, element_count);
      break;
    case ScalarType::BFloat16:
      loadType = llvm::VectorType::get(ShortTy_, element_count);
      break;
    default:
      throw std::runtime_error("invalid dtype in Load");
  }

  // Handle the case where the load is contiguous and unmasked efficiently.
  auto idx_ramp = to<Ramp>(v->flat_index());
  if (idx_ramp) {
    auto stride_imm = intValue(idx_ramp->stride());
    if (stride_imm && *stride_imm == 1) {
      v->base_handle()->accept(this);
      auto base = this->value_;
      idx_ramp->base()->accept(this);
      auto first_idx = this->value_;

#if LLVM_VERSION_MAJOR >= 15
      auto addr = irb_.CreateGEP(dtypeToLLVM(v->dtype()), base, first_idx);
#else
      auto addr = irb_.CreateGEP(
          base->getType()->getScalarType()->getPointerElementType(),
          base,
          first_idx);
#endif

      auto vaddr = irb_.CreateBitOrPointerCast(
          addr, llvm::PointerType::get(loadType, 0));
#if LLVM_VERSION_MAJOR >= 12
      value_ = irb_.CreateAlignedLoad(loadType, vaddr, llvm::MaybeAlign(4));
#else
      value_ = irb_.CreateAlignedLoad(loadType, vaddr, 4);
#endif
      return;
    }
  }

  // Fall back to a scalar implementation.
  v->base_handle()->accept(this);
  auto base = this->value_;
  v->flat_index()->accept(this);
  auto idx = this->value_;

  llvm::Value* load = llvm::UndefValue::get(loadType);
  for (int i = 0; i < v->dtype().lanes(); ++i) {
    auto sub_idx = irb_.CreateExtractElement(idx, i);
    llvm::Value* sub_load = nullptr;
    sub_load = emitUnmaskedLoad(dtypeToLLVM(v->dtype()), base, sub_idx);
    load = irb_.CreateInsertElement(load, sub_load, i);
  }

  value_ = load;
}

#if LLVM_VERSION_MAJOR >= 15
// Pack the arguments into an aggregate struct for forwarding.
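// E.g. packing {i64 n, float* buf} allocates a stack slot of struct type
// "{ i64, float* }" and stores each value into its field; unpackFuncArgs
// below indexes the same struct type to read the values back out.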
TypedPointer LLVMCodeGenImpl::packFuncArgs(
    const std::vector<llvm::Value*>& func_args) {
  if (func_args.empty()) {
    llvm::PointerType* VoidPtrType =
        llvm::PointerType::getUnqual(getContext());
    return TypedPointer(
        VoidPtrType, llvm::ConstantPointerNull::get(VoidPtrType));
  }
  std::vector<llvm::Type*> arg_types(func_args.size());
  for (const auto i : c10::irange(func_args.size())) {
    arg_types[i] = func_args[i]->getType();
  }
  llvm::StructType* packed_type = llvm::StructType::create(arg_types);
  llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
  llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
  llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
  for (const auto i : c10::irange(func_args.size())) {
    llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
        packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
    irb_.CreateStore(func_args[i], dst_ptr);
  }
  return TypedPointer(packed_type, packed);
}

// Unpack the aggregate struct into individual arguments.
std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
    TypedPointer packed,
    int arg_count) {
  // TODO: extract arg_count from packed.
  std::vector<llvm::Value*> func_args(arg_count);
  llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
  for (const auto i : c10::irange(arg_count)) {
    llvm::Type* field_type = packed.type->getStructElementType(i);
    llvm::Value* field_addr = irb_.CreateInBoundsGEP(
        packed.type, packed.addr, {zero, llvm::ConstantInt::get(IntTy_, i)});
    func_args[i] = irb_.CreateLoad(field_type, field_addr);
  }
  return func_args;
}
#else
// Pack the arguments into an aggregate struct for forwarding.
llvm::Value* LLVMCodeGenImpl::packFuncArgs(
    const std::vector<llvm::Value*>& func_args) {
  if (func_args.empty()) {
    llvm::PointerType* VoidPtrType = llvm::Type::getInt8PtrTy(getContext());
    llvm::Constant* NullPtr = llvm::ConstantPointerNull::get(VoidPtrType);
    return NullPtr;
  }
  std::vector<llvm::Type*> arg_types(func_args.size());
  for (const auto i : c10::irange(func_args.size())) {
    arg_types[i] = func_args[i]->getType();
  }
  llvm::StructType* packed_type = llvm::StructType::create(arg_types);
  llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
  llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
  llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
  for (const auto i : c10::irange(func_args.size())) {
    llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
        packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
    irb_.CreateStore(func_args[i], dst_ptr);
  }
  return packed;
}

// Unpack the aggregate struct into individual arguments.
std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
    llvm::Value* packed,
    int arg_count) {
  // TODO: extract arg_count from packed.
  std::vector<llvm::Value*> func_args(arg_count);
  llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
  for (const auto i : c10::irange(arg_count)) {
    llvm::Type* packed_type = packed->getType()->getPointerElementType();
    llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
        packed_type, packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
    func_args[i] =
        irb_.CreateLoad(dst_ptr->getType()->getPointerElementType(), dst_ptr);
  }
  return func_args;
}
#endif

// Lower the parallel for-loop:
// * Move the body into its own closure.
// * Identify the vars that cross the closure boundary, turn them into
//   arguments, and forward them.
// * Send the closure and range to the dispatcher for execution.
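// The closure is emitted as "void func(int64_t index, void* packed_args)"
// (see func_type below), and the runtime entry point "DispatchParallel(
// void* func, int64_t start, int64_t stop, void* packed_args)" is expected
// to invoke it for each index in [start, stop).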
void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
  // Create "start" and "stop" values.
  v->start()->accept(this);
  auto start = this->value_;
  v->stop()->accept(this);
  auto stop = this->value_;

  // The Vars that need to be forwarded into the body closure.
  std::vector<VarPtr> body_arg_vars;
  // Corresponding Value* that was used in the old body for the caller.
  std::vector<llvm::Value*> body_caller_vals;
  // Corresponding Value* that will be used in the new body closure.
  std::vector<llvm::Value*> body_closure_args;

  // Identify the Vars that are used in the body but generated outside of it.
  VarFinder var_finder;
  v->body()->accept(&var_finder);
  auto& vars = var_finder.vars();
  for (auto& var : vars) {
    if (llvm::Value* value = varToValue(var)) {
      body_arg_vars.push_back(var);
      body_caller_vals.push_back(value);
    }
  }

  // Pack the arguments in an automatic variable for forwarding.
#if LLVM_VERSION_MAJOR >= 15
  TypedPointer packData = packFuncArgs(body_caller_vals);
  llvm::Value* packed_caller_args = packData.addr;
#else
  llvm::Value* packed_caller_args = packFuncArgs(body_caller_vals);
#endif
  // Remember where we are before moving to the new function.
  llvm::BasicBlock* old_insert_block = irb_.GetInsertBlock();

  // Create the new body closure code.
#if LLVM_VERSION_MAJOR >= 15
  auto func_type =
      llvm::FunctionType::get(VoidTy_, {LongTy_, OpqPtrTy_}, false);
#else
  auto func_type =
      llvm::FunctionType::get(VoidTy_, {LongTy_, Int8PtrTy_}, false);
#endif

  llvm::Function* func = llvm::Function::Create(
      func_type, llvm::Function::PrivateLinkage, "func", module_.get());
  auto func_body = llvm::BasicBlock::Create(getContext(), "func_body", func);
  irb_.SetInsertPoint(func_body);
  auto args = func->arg_begin();
  llvm::Value* index = args++;
  llvm::Value* packed_func_args_raw = args++;
  llvm::Value* packed_func_args = irb_.CreatePointerCast(
      packed_func_args_raw, packed_caller_args->getType());

  // Unpack the arguments from the opaque buffer.
  if (v->var()->dtype().scalar_type() != c10::kLong) {
    index = irb_.CreateIntCast(
        index, dtypeToLLVM(v->var()->dtype()), v->var()->dtype().is_signed());
  }
#if LLVM_VERSION_MAJOR >= 15
  body_closure_args =
      unpackFuncArgs({packData.type, packed_func_args}, body_arg_vars.size());
#else
  body_closure_args = unpackFuncArgs(packed_func_args, body_arg_vars.size());
#endif
  // Set the codegen to the new func.
  // TODO: this should be replaced by RAII wrappers.
  varToVal_[v->var()] = index;
  replaceVarMapping(body_arg_vars, body_closure_args);
  llvm::Function* old_fn = fn_;
  fn_ = func;
  if (v->body()) {
    v->body()->accept(this);
  }
  // Restore back to the previous fn_.
  fn_ = old_fn;
  irb_.CreateRet(nullptr);
  replaceVarMapping(body_arg_vars, body_caller_vals);
  varToVal_.erase(v->var());

1661 // Points back to the original block and generate the callee code.
1662 irb_.SetInsertPoint(old_insert_block);
1663
1664 #if LLVM_VERSION_MAJOR >= 15
1665 llvm::Value* packed_caller_args_ptr =
1666 irb_.CreatePointerCast(packed_caller_args, OpqPtrTy_);
1667 llvm::Value* func_value = irb_.CreatePointerCast(func, OpqPtrTy_);
1668 llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get(
1669 VoidTy_, {OpqPtrTy_, LongTy_, LongTy_, OpqPtrTy_}, false);
1670 #else
1671 llvm::Value* packed_caller_args_ptr =
1672 irb_.CreatePointerCast(packed_caller_args, Int8PtrTy_);
1673 llvm::Value* func_value = irb_.CreatePointerCast(func, Int8PtrTy_);
1674 llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get(
1675 VoidTy_, {Int8PtrTy_, LongTy_, LongTy_, Int8PtrTy_}, false);
1676 #endif
1677
1678 FunctionCallee dispatcher_callee =
1679 module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
1680 llvm::Function* dispatcher =
1681 llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
1682 dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
1683 start = irb_.CreateIntCast(start, LongTy_, true);
1684 stop = irb_.CreateIntCast(stop, LongTy_, true);
1685 irb_.CreateCall(
1686 dispatcher, {func_value, start, stop, packed_caller_args_ptr});
1687 value_ = llvm::ConstantInt::get(IntTy_, 0);
1688 }
1689
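// The call emitted above assumes a runtime dispatcher with the C-level
// contract below. A minimal sketch, assuming a serial fallback (the actual
// NNC runtime implementation may run the range via at::parallel_for):
//
//   using ParallelCallee = void (*)(int64_t index, int8_t* packed_args);
//   extern "C" void DispatchParallel(
//       int8_t* func, int64_t start, int64_t stop, int8_t* packed_data) {
//     auto callee = reinterpret_cast<ParallelCallee>(func);
//     for (int64_t i = start; i < stop; ++i) {
//       callee(i, packed_data);
//     }
//   }
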
void LLVMCodeGenImpl::visit(const ForPtr& v) {
  if (v->is_parallel()) {
    processParallelFor(v);
    return;
  }

  // Create "start" and "stop" values.
  v->start()->accept(this);
  auto start = this->value_;
  v->stop()->accept(this);
  auto stop = this->value_;

  // Create block for loop condition test.
  auto preheader = irb_.GetInsertBlock();
  auto condBlock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
  irb_.CreateBr(condBlock);
  irb_.SetInsertPoint(condBlock);

  // Set up phi node for index variable.
  auto idx = irb_.CreatePHI(start->getType(), 2);
  idx->addIncoming(start, preheader);
  if (!varToVal_.count(v->var())) {
    varToVal_.emplace(v->var(), idx);
  } else {
    throw std::runtime_error("var should not exist before");
  }

  // Create the body and exit blocks.
  auto body = llvm::BasicBlock::Create(getContext(), "body", fn_);
  auto exit = llvm::BasicBlock::Create(getContext(), "exit", fn_);

  // Create the stop condition.
  auto cond = irb_.CreateICmpSLT(idx, stop);
  irb_.CreateCondBr(cond, body, exit);

  // Codegen the body.
  irb_.SetInsertPoint(body);
  if (v->body()) {
    v->body()->accept(this);
  }
  // The "body" block may have changed if we generated nested control flow.
  body = irb_.GetInsertBlock();

  // Increment the index variable and branch back to the loop test.
  auto inc =
      irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(start->getType(), 1));
  irb_.CreateBr(condBlock);
  idx->addIncoming(inc, body);

  // Exit the loop.
  irb_.SetInsertPoint(exit);

  varToVal_.erase(v->var());
  value_ = llvm::ConstantInt::get(IntTy_, 0);
}

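// For reference, the CFG generated for a scalar loop over [start, stop) looks
// like this (illustrative IR):
//
//   preheader:
//     br label %cond
//   cond:
//     %idx = phi i32 [ %start, %preheader ], [ %inc, %body ]
//     %cmp = icmp slt i32 %idx, %stop
//     br i1 %cmp, label %body, label %exit
//   body:
//     ...                        ; loop body
//     %inc = add i32 %idx, 1
//     br label %cond
//   exit:
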
void LLVMCodeGenImpl::visit(const BlockPtr& v) {
  BlockPtr last = scope_;
  scope_ = v;

  for (StmtPtr s : *v) {
    s->accept(this);
  }

  scope_ = last;

  auto it = scopeToVar_.find(v);
  if (it != scopeToVar_.end()) {
    for (VarPtr e : it->second) {
      if (varToVal_.erase(e) != 1) {
        throw std::runtime_error("erasing var that doesn't exist");
      }
    }
  }
}

void LLVMCodeGenImpl::emitUnmaskedStore(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx,
    llvm::Value* val) {
#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
#endif

  irb_.CreateStore(val, addr);
}

void LLVMCodeGenImpl::emitMaskedStore(
    llvm::Type* ty,
    llvm::Value* base,
    llvm::Value* idx,
    llvm::Value* mask,
    llvm::Value* val) {
  // Create block structure for the masked store.
  auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_);
  auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_);

  // Test the mask
  auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1));
  irb_.CreateCondBr(cond, condblock, tailblock);

  // Do the store
  irb_.SetInsertPoint(condblock);

#if LLVM_VERSION_MAJOR >= 15
  auto addr = irb_.CreateGEP(ty, base, idx);
#else
  auto addr = irb_.CreateGEP(
      base->getType()->getScalarType()->getPointerElementType(), base, idx);
#endif

  irb_.CreateStore(val, addr);
  irb_.CreateBr(tailblock);

  // Merge the masked and unmasked CFG edges
  irb_.SetInsertPoint(tailblock);
}

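// The masked store is thus lowered to a branch around a scalar store
// (illustrative IR):
//
//   %cmp = icmp eq i32 %mask, 1
//   br i1 %cmp, label %cond, label %tail
//   cond:
//     %addr = getelementptr float, float* %base, i32 %idx
//     store float %val, float* %addr
//     br label %tail
//   tail:
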
void LLVMCodeGenImpl::visit(const StorePtr& v) {
  if (v->value()->dtype().lanes() == 1) {
    v->base_handle()->accept(this);
    auto base = this->value_;
    v->flat_index()->accept(this);
    auto idx = this->value_;
    v->value()->accept(this);
    auto val = this->value_;

    emitUnmaskedStore(dtypeToLLVM(v->value()->dtype()), base, idx, val);
    value_ = llvm::ConstantInt::get(IntTy_, 0);
    return;
  }

  v->base_handle()->accept(this);
  auto base = this->value_;
  v->value()->accept(this);
  auto val = this->value_;

  // Efficiently handle the case where the store is contiguous and unmasked.
  auto idx_ramp = to<Ramp>(v->flat_index());
  if (idx_ramp) {
    auto stride_imm = intValue(idx_ramp->stride());
    if (stride_imm && *stride_imm == 1) {
      idx_ramp->base()->accept(this);
      auto first_idx = value_;

#if LLVM_VERSION_MAJOR >= 15
      auto addr =
          irb_.CreateGEP(dtypeToLLVM(v->value()->dtype()), base, first_idx);
#else
      auto addr = irb_.CreateGEP(
          base->getType()->getScalarType()->getPointerElementType(),
          base,
          first_idx);
#endif

      auto vaddr = irb_.CreateBitOrPointerCast(
          addr, llvm::PointerType::get(val->getType(), 0));

#if LLVM_VERSION_MAJOR >= 13
      irb_.CreateAlignedStore(val, vaddr, llvm::MaybeAlign(4));
#else
      irb_.CreateAlignedStore(val, vaddr, 4);
#endif
      value_ = llvm::ConstantInt::get(IntTy_, 0);
      return;
    }
  }

  v->flat_index()->accept(this);
  auto idx = this->value_;

  // Fall back to a scalar implementation.
  for (int i = 0; i < v->value()->dtype().lanes(); ++i) {
    auto sub_idx = irb_.CreateExtractElement(idx, i);
    auto sub_val = irb_.CreateExtractElement(val, i);
    emitUnmaskedStore(dtypeToLLVM(v->value()->dtype()), base, sub_idx, sub_val);
  }

  value_ = llvm::ConstantInt::get(IntTy_, 0);
}

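// Example of the contiguous fast path above: a Store whose flat index is
// Ramp(first_idx, /*stride=*/1, /*lanes=*/8) of float lowers to one vector
// store instead of eight scalar stores (illustrative IR):
//
//   %addr  = getelementptr float, float* %base, i32 %first_idx
//   %vaddr = bitcast float* %addr to <8 x float>*
//   store <8 x float> %val, <8 x float>* %vaddr, align 4
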
void LLVMCodeGenImpl::visit(const BroadcastPtr& v) {
  v->value()->accept(this);
  int lanes = v->lanes();
  value_ = irb_.CreateVectorSplat(lanes, value_);
}

void LLVMCodeGenImpl::visit(const IfThenElsePtr& v) {
  v->condition()->accept(this);
  llvm::Value* condition = value_;
  llvm::Value* c = irb_.CreateICmpNE(
      condition, llvm::ConstantInt::get(condition->getType(), 0));

  auto then_block = llvm::BasicBlock::Create(getContext(), "then", fn_);
  auto else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
  auto end_block = llvm::BasicBlock::Create(getContext(), "block", fn_);
  irb_.CreateCondBr(c, then_block, else_block);

  irb_.SetInsertPoint(then_block);
  v->true_value()->accept(this);
  llvm::Value* then_val = value_;
  then_block = irb_.GetInsertBlock();
  irb_.CreateBr(end_block);

  irb_.SetInsertPoint(else_block);
  v->false_value()->accept(this);
  llvm::Value* else_val = value_;
  else_block = irb_.GetInsertBlock();
  irb_.CreateBr(end_block);

  irb_.SetInsertPoint(end_block);
  llvm::PHINode* phi = irb_.CreatePHI(then_val->getType(), 2);
  phi->addIncoming(then_val, then_block);
  phi->addIncoming(else_val, else_block);
  value_ = phi;
}

static void applyMathFunctionAttributes(llvm::Function* f) {
  f->addFnAttr(llvm::Attribute::ReadNone);
  f->addFnAttr(llvm::Attribute::NoUnwind);
  // TODO: Adding this attr should be correct, but as of LLVM 9.0.1 adding it
  // causes some math functions to incorrectly be turned into tail calls.
  // f->addFnAttr(llvm::Attribute::Speculatable);
#if LLVM_VERSION_MAJOR >= 9
  f->addFnAttr(llvm::Attribute::NoFree);
  f->addFnAttr(llvm::Attribute::WillReturn);
#endif
}

llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) {
  if (lanes > 1) {
    return irb_.CreateVectorSplat(lanes, v);
  } else {
    return v;
  }
}

void LLVMCodeGenImpl::emitIsNan(IntrinsicsPtr v) {
  v->param(0)->accept(this);
  llvm::Type* dstType = dtypeToLLVM(v->dtype());
  if (!v->param(0)->dtype().is_floating_point()) {
    value_ = toVec(llvm::ConstantInt::get(dstType, 0), v->dtype().lanes());
  } else {
    TORCH_INTERNAL_ASSERT(
        v->dtype().scalar_type() == ScalarType::Int,
        buildErrorMessage(
            "Unexpected non-Int dtype of Intrinsics' result value in the fuser."));
    auto is_nan = irb_.CreateFCmpUNO(
        value_, llvm::ConstantFP::get(value_->getType(), 0.));
    if (v->dtype().lanes() > 1) {
      dstType =
          llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes()));
    }
    value_ = irb_.CreateIntCast(is_nan, dstType, /*isSigned*/ false);
  }
}

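// Note on the lowering above: "fcmp uno" (unordered) is true iff at least one
// operand is NaN. Since the constant 0.0 is never NaN, comparing the input
// against it yields true exactly when the input is NaN (illustrative IR):
//
//   %nan = fcmp uno float %x, 0.000000e+00
//   %res = zext i1 %nan to i32
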
static bool wantSleef(const std::string& name) {
  // Using Sleef for these ops is slower than libm.
  static std::unordered_set<std::string> noSleef = {
      "sqrt",
      "ceil",
      "trunc",
      "fabs",
      "floor",
      "sqrtf",
      "ceilf",
      "truncf",
      "fabsf",
      "floorf",
  };
  return noSleef.find(name) == noSleef.end();
}

LLVMCodeGenImpl::SimdCallee LLVMCodeGenImpl::getSimdFunction(
    const std::string& basename,
    llvm::Type* basetype,
    Arity arity,
    int lanes) {
  std::string name;
  llvm::Type* type;
  bool useSimd;

  // Determine whether to use a vectorized intrinsic.
  auto const& featureString = jit_->getTargetMachine().getTargetFeatureString();
  bool hasAVX = featureString.find("+avx") != llvm::StringRef::npos;
  std::string typeSuffix = basetype == DoubleTy_ ? "d" : "";
  std::string sleefName =
      "Sleef_" + basename + typeSuffix + std::to_string(lanes);
  if (wantSleef(basename) && hasAVX && jit_->hasSymbol(sleefName)) {
    name = std::move(sleefName);
    type = llvm::VectorType::get(basetype, ElementCount(lanes));
    useSimd = true;
  } else {
    name = basename;
    type = basetype;
    useSimd = false;
  }

  // Get the function to call from its name and type.
  llvm::FunctionType* fntype;
  switch (arity) {
    case Unary:
      fntype = llvm::FunctionType::get(type, {type}, false);
      break;
    case Binary:
      fntype = llvm::FunctionType::get(type, {type, type}, false);
      break;
  }
  FunctionCallee callee = module_->getOrInsertFunction(name, fntype, {});
  applyMathFunctionAttributes(llvm::cast<llvm::Function>(callee.getCallee()));
  return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
}

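// Example of the name mapping above (assuming an AVX target where the JIT can
// resolve the Sleef symbol): basename "expf" with 8 float lanes becomes
// "Sleef_expf8" of type <8 x float> (<8 x float>), while basename "atan2"
// with 4 double lanes becomes "Sleef_atan2d4" of type
// <4 x double> (<4 x double>, <4 x double>). When no Sleef symbol is
// available, the scalar libm function is used and the caller emits one call
// per lane.
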
void LLVMCodeGenImpl::visit(const IntrinsicsPtr& v) {
  llvm::FunctionType* call_ty = nullptr;
  llvm::Value* call_fn = nullptr;
  bool call_simd_sleef = false;

  if (v->op_type() == kIsNan) {
    return emitIsNan(v);
  }

  if (v->dtype().scalar_type() == ScalarType::Float) {
    switch (v->op_type()) {
      case kRsqrt: {
        v->params().front()->accept(this);
        value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_);
        llvm::Value* constant =
            toVec(llvm::ConstantFP::get(FloatTy_, 1.0), v->dtype().lanes());
        value_ = irb_.CreateFDiv(constant, value_);
        return;
      } break;

#define SIMD_UNARY_MATH_CASE(enum, name, type)                  \
  case enum: {                                                  \
    std::tie(call_ty, call_fn, call_simd_sleef) =               \
        getSimdFunction(name, type, Unary, v->dtype().lanes()); \
  } break;
      SIMD_UNARY_MATH_CASE(kLog10, "log10f", FloatTy_)
      SIMD_UNARY_MATH_CASE(kLog, "logf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kLog1p, "log1pf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kLog2, "log2f", FloatTy_)
      SIMD_UNARY_MATH_CASE(kExp, "expf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kCos, "cosf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kSin, "sinf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kSqrt, "sqrtf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kAbs, "fabsf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kFloor, "floorf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kCeil, "ceilf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kTrunc, "truncf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kRound, "nearbyint", FloatTy_)
      SIMD_UNARY_MATH_CASE(kErf, "erff", FloatTy_)
      SIMD_UNARY_MATH_CASE(kErfc, "erfcf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kTan, "tanf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kAcos, "acosf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kAsin, "asinf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kAtan, "atanf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kCosh, "coshf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kSinh, "sinhf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kTanh, "tanhf", FloatTy_)
      SIMD_UNARY_MATH_CASE(kExpm1, "expm1f", FloatTy_)
      SIMD_UNARY_MATH_CASE(kLgamma, "lgammaf", FloatTy_)
#undef SIMD_UNARY_MATH_CASE

#define SIMD_BINARY_MATH_CASE(enum, name, type)                  \
  case enum: {                                                   \
    std::tie(call_ty, call_fn, call_simd_sleef) =                \
        getSimdFunction(name, type, Binary, v->dtype().lanes()); \
  } break;
      SIMD_BINARY_MATH_CASE(kAtan2, "atan2f", FloatTy_)
      SIMD_BINARY_MATH_CASE(kPow, "powf", FloatTy_)
      SIMD_BINARY_MATH_CASE(kFmod, "fmodf", FloatTy_)
#undef SIMD_BINARY_MATH_CASE

      case kRemainder: {
        FunctionCallee callee = module_->getOrInsertFunction(
            "remainderf",
            llvm::FunctionType::get(FloatTy_, {FloatTy_, FloatTy_}, false),
            {});
        call_ty = callee.getFunctionType();
        call_fn = callee.getCallee();
        applyMathFunctionAttributes(llvm::cast<llvm::Function>(call_fn));
      } break;

      default: {
        throw unimplemented_lowering(v);
      } break;
    }

  } else if (v->dtype().scalar_type() == ScalarType::Double) {
    switch (v->op_type()) {
#define SIMD_UNARY_MATH_CASE(enum, name, type)                  \
  case enum: {                                                  \
    std::tie(call_ty, call_fn, call_simd_sleef) =               \
        getSimdFunction(name, type, Unary, v->dtype().lanes()); \
  } break;
      SIMD_UNARY_MATH_CASE(kLog10, "log10", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kLog, "log", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kLog1p, "log1p", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kLog2, "log2", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kExp, "exp", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kCos, "cos", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kSin, "sin", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kSqrt, "sqrt", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kAbs, "fabs", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kFloor, "floor", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kCeil, "ceil", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kTrunc, "trunc", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kRound, "nearbyint", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kErf, "erf", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kErfc, "erfc", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kTan, "tan", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kAcos, "acos", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kAsin, "asin", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kAtan, "atan", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kCosh, "cosh", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kSinh, "sinh", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kTanh, "tanh", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kExpm1, "expm1", DoubleTy_)
      SIMD_UNARY_MATH_CASE(kLgamma, "lgamma", DoubleTy_)
#undef SIMD_UNARY_MATH_CASE

      case kRsqrt: {
        v->params().front()->accept(this);
        value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_);
        llvm::Value* constant = llvm::ConstantFP::get(DoubleTy_, 1.0);
        if (v->dtype().lanes() > 1) {
          constant = irb_.CreateVectorSplat(v->dtype().lanes(), constant);
        }
        value_ = irb_.CreateFDiv(constant, value_);
        return;
      } break;

#define SIMD_BINARY_MATH_CASE(enum, name, type)                  \
  case enum: {                                                   \
    std::tie(call_ty, call_fn, call_simd_sleef) =                \
        getSimdFunction(name, type, Binary, v->dtype().lanes()); \
  } break;
      SIMD_BINARY_MATH_CASE(kAtan2, "atan2", DoubleTy_)
      SIMD_BINARY_MATH_CASE(kPow, "pow", DoubleTy_)
      SIMD_BINARY_MATH_CASE(kFmod, "fmod", DoubleTy_)
#undef SIMD_BINARY_MATH_CASE

      case kRemainder: {
        FunctionCallee callee = module_->getOrInsertFunction(
            "remainder",
            llvm::FunctionType::get(DoubleTy_, {DoubleTy_, DoubleTy_}, false),
            {});
        call_ty = callee.getFunctionType();
        call_fn = callee.getCallee();
        applyMathFunctionAttributes(llvm::cast<llvm::Function>(call_fn));
      } break;

      default: {
        throw unimplemented_lowering(v);
      } break;
    }
  } else if (v->dtype().is_integral() && v->op_type() == kAbs) {
    // abs is the only intrinsic defined for integer inputs in PyTorch eager.
    v->params().front()->accept(this);
    if (!v->dtype().is_signed()) {
      return;
    }
    // TODO: use llvm.abs intrinsic for LLVM 12
    auto zero = llvm::ConstantInt::get(value_->getType(), 0);
    auto neg_value = irb_.CreateSub(zero, value_);
    auto icmp = irb_.CreateICmpSGT(value_, zero);
    value_ = irb_.CreateSelect(icmp, value_, neg_value);
    return;
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        buildErrorMessage(
            std::string("Unimplemented lowering for intrinsic '") +
            std::to_string(v->op_type()) + "' for input of dtype " +
            std::to_string(v->dtype().scalar_dtype()) +
            " in LLVM codegen of the fuser."));
  }

  std::vector<llvm::Value*> params;
  for (auto& p : v->params()) {
    p->accept(this);
    params.push_back(value_);
  }

  if (v->dtype().lanes() == 1 || call_simd_sleef) {
    value_ = irb_.CreateCall(call_ty, call_fn, params);
  } else {
    // No SIMD implementation is available: call the scalar function once per
    // lane and reassemble the vector result.
    llvm::Type* vecType = params[0]->getType();
    value_ = llvm::UndefValue::get(vecType);
    for (int i = 0; i < v->dtype().lanes(); ++i) {
      std::vector<llvm::Value*> call_operands;
      for (auto p : params) {
        call_operands.push_back(irb_.CreateExtractElement(p, i));
      }

      llvm::Value* val = irb_.CreateCall(call_ty, call_fn, call_operands);
      value_ = irb_.CreateInsertElement(value_, val, i);
    }
  }
}

void LLVMCodeGenImpl::handleBufReuse(BufPtr buf, BufPtr buf_to_reuse) {
  llvm::Value* ptr = varToVal_.at(buf_to_reuse->base_handle());
  if (buf_to_reuse->dtype().scalar_type() != buf->dtype().scalar_type()) {
    ptr = irb_.CreatePointerCast(ptr, dtypeToLLVMPtr(buf->dtype()));
  }
  varToVal_[buf->base_handle()] = ptr;
}

void LLVMCodeGenImpl::visit(const ExternalCallPtr& v) {
  auto& func_registry = getNNCFunctionRegistry();
  if (!func_registry.count(v->func_name())) {
    throw unimplemented_lowering(v);
  }

  // Prepare a vector of bufs that we need to pass to the external function.
  // This vector is the output buf followed by the buf_args.
  std::vector<BufPtr> bufs(v->buf_args());
  bufs.insert(bufs.begin(), v->buf());

  int64_t bufs_num = bufs.size();
  int64_t args_num = v->args().size();

  // Count the size of the dims array: it is the concatenation of the
  // dimensions of all bufs.
  int64_t dims_num = 0;
  for (BufPtr b : bufs) {
    dims_num += b->dims().size();
  }
#if LLVM_VERSION_MAJOR >= 15
  llvm::Value* buf_ptrs = irb_.CreateAlloca(
      OpqPtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
#else
  llvm::Value* buf_ptrs = irb_.CreateAlloca(
      Int8PtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
#endif
  llvm::Value* buf_ranks = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
  llvm::Value* buf_dims = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
  llvm::Value* buf_strides = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
  llvm::Value* buf_dtypes = irb_.CreateAlloca(
      ByteTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
  llvm::Value* extra_args = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, args_num));

  int i = 0;
  int dim_idx = 0;
  int stride_idx = 0;
  for (BufPtr b : bufs) {
    // Store value for buf pointer
    b->base_handle()->accept(this);
    auto buf_ptr = this->value_;
#if LLVM_VERSION_MAJOR >= 15
    auto gep = irb_.CreateInBoundsGEP(
        OpqPtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
    auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, OpqPtrTy_);
#else
    auto gep = irb_.CreateInBoundsGEP(
        Int8PtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
    auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, Int8PtrTy_);
#endif
    irb_.CreateStore(buf_void_ptr, gep);

    // Store dtype of the buf
    gep = irb_.CreateInBoundsGEP(
        ByteTy_, buf_dtypes, llvm::ConstantInt::getSigned(IntTy_, i));
    irb_.CreateStore(
        llvm::ConstantInt::getSigned(ByteTy_, (int8_t)b->dtype().scalar_type()),
        gep);

    // Store rank of the buf
    gep = irb_.CreateInBoundsGEP(
        LongTy_, buf_ranks, llvm::ConstantInt::getSigned(IntTy_, i));
    irb_.CreateStore(
        llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);

    // Store dims of the buf
    for (const auto dim : c10::irange(b->dims().size())) {
      gep = irb_.CreateInBoundsGEP(
          LongTy_, buf_dims, llvm::ConstantInt::getSigned(IntTy_, dim_idx));
      b->dims()[dim]->accept(this);
      auto dim_val = this->value_;
      irb_.CreateStore(irb_.CreateZExt(dim_val, LongTy_), gep);
      dim_idx++;
    }

    // Store strides of the buf
    for (const auto dim : c10::irange(b->dims().size())) {
      gep = irb_.CreateInBoundsGEP(
          LongTy_,
          buf_strides,
          llvm::ConstantInt::getSigned(IntTy_, stride_idx));
      b->strides()[dim]->accept(this);
      auto stride_val = this->value_;
      irb_.CreateStore(irb_.CreateZExt(stride_val, LongTy_), gep);
      stride_idx++;
    }

    i++;
  }

  i = 0;
  for (ExprPtr arg : v->args()) {
    auto gep = irb_.CreateInBoundsGEP(
        LongTy_, extra_args, llvm::ConstantInt::getSigned(IntTy_, i));
    arg->accept(this);
    irb_.CreateStore(irb_.CreateZExtOrBitCast(this->value_, LongTy_), gep);
    i++;
  }

  // Generate the call itself
  std::string fname = v->func_name();
#if LLVM_VERSION_MAJOR >= 15
  FunctionCallee callee = module_->getOrInsertFunction(
      fname,
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {LongTy_, // int64_t bufs_num
           OpqPtrTy_, // void** buf_data
           OpqPtrTy_, // int64_t* buf_ranks
           OpqPtrTy_, // int64_t* buf_dims
           OpqPtrTy_, // int64_t* buf_strides
           OpqPtrTy_, // int8_t* buf_dtypes
           LongTy_, // int64_t args_num
           OpqPtrTy_}, // int64_t* extra_args
          false)); // is var_arg
#else
  FunctionCallee callee = module_->getOrInsertFunction(
      fname,
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {LongTy_, // int64_t bufs_num
           Int8PtrTy_->getPointerTo(), // void** buf_data
           LongTy_->getPointerTo(), // int64_t* buf_ranks
           LongTy_->getPointerTo(), // int64_t* buf_dims
           LongTy_->getPointerTo(), // int64_t* buf_strides
           ByteTy_->getPointerTo(), // int8_t* buf_dtypes
           LongTy_, // int64_t args_num
           LongTy_->getPointerTo()}, // int64_t* extra_args
          false)); // is var_arg
#endif

  auto call_ty = callee.getFunctionType();
  auto call_fn = callee.getCallee();
  llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);

  irb_.CreateCall(
      call_ty,
      call_fn,
      {llvm::ConstantInt::getSigned(LongTy_, bufs_num),
       buf_ptrs,
       buf_ranks,
       buf_dims,
       buf_strides,
       buf_dtypes,
       llvm::ConstantInt::getSigned(LongTy_, args_num),
       extra_args});

  value_ = llvm::ConstantInt::get(IntTy_, 0);
}

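// Every function reachable through this visitor must match the C signature
// implied by the FunctionType above. A sketch of a conforming callee (the
// name is hypothetical; real implementations are registered in the NNC
// external functions registry):
//
//   extern "C" void nnc_aten_example(
//       int64_t bufs_num,
//       void** buf_data, // buf_data[0] is the output, then the inputs
//       int64_t* buf_ranks,
//       int64_t* buf_dims, // ranks/dims/strides are concatenated per buf
//       int64_t* buf_strides,
//       int8_t* buf_dtypes,
//       int64_t args_num,
//       int64_t* extra_args);
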
void LLVMCodeGenImpl::visit(const ExternalCallWithAllocPtr& v) {
  auto& func_registry = getNNCFunctionRegistry();
  if (!func_registry.count(v->func_name())) {
    throw unimplemented_lowering(v);
  }

  const auto& bufs_out = v->buf_out_args();
  const auto& bufs_in = v->buf_args();

  const auto bufs_in_size = bufs_in.size();
  const auto bufs_out_size = bufs_out.size();
  const auto args_num = v->args().size();

  // Count the size of the dims array: it is the concatenation of the
  // dimensions of all input bufs.
  size_t dims_num = 0;
  for (const auto& b : bufs_in) {
    dims_num += b->dims().size();
  }

  // buf_ptrs holds:
  // * bufs_out_size data pointers for the out tensors,
  // * bufs_in_size data pointers for the inputs,
  // * bufs_out_size TensorImpl* for the out tensors, later passed to
  //   nnc_aten_free to release them.
#if LLVM_VERSION_MAJOR >= 15
  llvm::Value* buf_ptrs = irb_.CreateAlloca(
      OpqPtrTy_,
      llvm::ConstantInt::getSigned(IntTy_, bufs_in_size + 2 * bufs_out_size));
#else
  llvm::Value* buf_ptrs = irb_.CreateAlloca(
      Int8PtrTy_,
      llvm::ConstantInt::getSigned(IntTy_, bufs_in_size + 2 * bufs_out_size));
#endif
  // @lint-ignore CLANGTIDY
  llvm::Value* buf_ranks = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_in_size));
  llvm::Value* buf_dims = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
  llvm::Value* buf_strides = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, dims_num));
  llvm::Value* buf_dtypes = irb_.CreateAlloca(
      ByteTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_in_size));
  // @lint-ignore CLANGTIDY
  llvm::Value* extra_args = irb_.CreateAlloca(
      LongTy_, llvm::ConstantInt::getSigned(IntTy_, args_num));

  int i = 0;
  int dim_idx = 0;
  int stride_idx = 0;
  for (const auto& b : bufs_in) {
    // Store value for buf pointer
    b->base_handle()->accept(this);
    auto buf_ptr = this->value_;

#if LLVM_VERSION_MAJOR >= 15
    llvm::Value* gep = irb_.CreateInBoundsGEP(
        OpqPtrTy_,
        buf_ptrs,
        // @lint-ignore CLANGTIDY
        llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + i));
    auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, OpqPtrTy_);
#else
    llvm::Value* gep = irb_.CreateInBoundsGEP(
        Int8PtrTy_,
        buf_ptrs,
        // @lint-ignore CLANGTIDY
        llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + i));
    auto buf_void_ptr = irb_.CreatePointerCast(buf_ptr, Int8PtrTy_);
#endif

    irb_.CreateStore(buf_void_ptr, gep);

    // Store dtype of the buf
    gep = irb_.CreateInBoundsGEP(
        ByteTy_, buf_dtypes, llvm::ConstantInt::getSigned(IntTy_, i));
    irb_.CreateStore(
        llvm::ConstantInt::getSigned(ByteTy_, (int8_t)b->dtype().scalar_type()),
        gep);

    // Store rank of the buf
    // @lint-ignore CLANGTIDY
    gep = irb_.CreateInBoundsGEP(
        LongTy_, buf_ranks, llvm::ConstantInt::getSigned(IntTy_, i));
    irb_.CreateStore(
        llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);

    // Store dims of the buf
    for (const auto dim : c10::irange(b->dims().size())) {
      gep = irb_.CreateInBoundsGEP(
          LongTy_, buf_dims, llvm::ConstantInt::getSigned(IntTy_, dim_idx));
      b->dims()[dim]->accept(this);
      auto dim_val = this->value_;
      irb_.CreateStore(irb_.CreateZExt(dim_val, LongTy_), gep);
      dim_idx++;
    }

    // Store strides of the buf
    for (const auto dim : c10::irange(b->dims().size())) {
      gep = irb_.CreateInBoundsGEP(
          LongTy_,
          buf_strides,
          llvm::ConstantInt::getSigned(IntTy_, stride_idx));
      b->strides()[dim]->accept(this);
      auto stride_val = this->value_;
      irb_.CreateStore(irb_.CreateZExt(stride_val, LongTy_), gep);
      stride_idx++;
    }

    i++;
  }

  i = 0;
  for (const ExprPtr& arg : v->args()) {
    auto gep = irb_.CreateInBoundsGEP(
        LongTy_, extra_args, llvm::ConstantInt::getSigned(IntTy_, i));
    arg->accept(this);
    irb_.CreateStore(irb_.CreateZExtOrBitCast(this->value_, LongTy_), gep);
    i++;
  }

  // Generate the call itself
  std::string fname = v->func_name();

#if LLVM_VERSION_MAJOR >= 15
  FunctionCallee callee = module_->getOrInsertFunction(
      fname,
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {LongTy_, // int64_t bufs_in_size
           OpqPtrTy_, // void** buf_data
           OpqPtrTy_, // int64_t* buf_ranks
           OpqPtrTy_, // int64_t* buf_dims
           OpqPtrTy_, // int64_t* buf_strides
           OpqPtrTy_, // int8_t* buf_dtypes
           LongTy_, // int64_t args_num
           OpqPtrTy_}, // int64_t* extra_args
          false)); // is var_arg
#else
  FunctionCallee callee = module_->getOrInsertFunction(
      fname,
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {LongTy_, // int64_t bufs_in_size
           Int8PtrTy_->getPointerTo(), // void** buf_data
           LongTy_->getPointerTo(), // int64_t* buf_ranks
           LongTy_->getPointerTo(), // int64_t* buf_dims
           LongTy_->getPointerTo(), // int64_t* buf_strides
           ByteTy_->getPointerTo(), // int8_t* buf_dtypes
           LongTy_, // int64_t args_num
           LongTy_->getPointerTo()}, // int64_t* extra_args
          false)); // is var_arg
#endif

  auto call_ty = callee.getFunctionType();
  auto call_fn = callee.getCallee();
  llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);

  irb_.CreateCall(
      call_ty,
      call_fn,
      // @lint-ignore CLANGTIDY
      {llvm::ConstantInt::getSigned(LongTy_, bufs_in_size),
       buf_ptrs,
       buf_ranks,
       buf_dims,
       buf_strides,
       buf_dtypes,
       // @lint-ignore CLANGTIDY
       llvm::ConstantInt::getSigned(LongTy_, args_num),
       extra_args});

  // @lint-ignore CLANGTIDY
  for (const auto i : c10::irange(bufs_out_size)) {
    const auto& buf_out = bufs_out[i];
#if LLVM_VERSION_MAJOR >= 15
    auto gep = irb_.CreateInBoundsGEP(
        OpqPtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
    llvm::Value* ptr = irb_.CreatePointerCast(
        irb_.CreateLoad(OpqPtrTy_, gep), dtypeToLLVMPtr(buf_out->dtype()));
#else
    auto gep = irb_.CreateInBoundsGEP(
        Int8PtrTy_, buf_ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
    llvm::Value* ptr = irb_.CreatePointerCast(
        irb_.CreateLoad(Int8PtrTy_, gep), dtypeToLLVMPtr(buf_out->dtype()));
#endif

    varToVal_[buf_out->base_handle()] = ptr;

    for (auto it = bufsExtAllocReuse_.find(buf_out);
         it != bufsExtAllocReuse_.end();
         it++) {
      auto buf = it->second;
      handleBufReuse(buf, buf_out);
    }
    bufsExtAllocReuse_.erase(buf_out);

#if LLVM_VERSION_MAJOR >= 15
    gep = irb_.CreateInBoundsGEP(
        OpqPtrTy_,
        buf_ptrs,
        // @lint-ignore CLANGTIDY
        llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + bufs_in_size + i));
    bufsExtToFreeVal_[buf_out->base_handle()] = irb_.CreateLoad(OpqPtrTy_, gep);
#else
    gep = irb_.CreateInBoundsGEP(
        Int8PtrTy_,
        buf_ptrs,
        // @lint-ignore CLANGTIDY
        llvm::ConstantInt::getSigned(IntTy_, bufs_out_size + bufs_in_size + i));
    bufsExtToFreeVal_[buf_out->base_handle()] =
        irb_.CreateLoad(Int8PtrTy_, gep);
#endif
  }

  value_ = llvm::ConstantInt::get(IntTy_, 0);
}

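// Layout of buf_ptrs as populated and consumed above, with M = bufs_out_size
// and N = bufs_in_size (illustrative):
//
//   [0, M)          output data pointers, written back by the callee
//   [M, M + N)      input data pointers, filled in before the call
//   [M + N, 2M + N) TensorImpl* handles for the outputs, stashed in
//                   bufsExtToFreeVal_ and later released via nnc_aten_free
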
void LLVMCodeGenImpl::visit(const AllocatePtr& v) {
  llvm::Value* size =
      llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
  for (ExprPtr e : v->dims()) {
    e->accept(this);
    size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
  }

  value_ = llvm::ConstantInt::get(IntTy_, 0);

  if (llvm::ConstantInt* CI = llvm::dyn_cast<llvm::ConstantInt>(size)) {
    if (CI->getSExtValue() < 512) {
      llvm::Value* alloca = irb_.CreateAlloca(dtypeToLLVM(v->dtype()), size);
      varToVal_[v->buffer_var()] = alloca;
      return;
    }
  }

#if LLVM_VERSION_MAJOR > 17
  llvm::Instruction* I = irb_.CreateMalloc(
      LongTy_, dtypeToLLVM(v->dtype()), size, nullptr, nullptr, "");
#else
  llvm::Instruction* I = llvm::CallInst::CreateMalloc(
      irb_.GetInsertBlock(),
      LongTy_,
      dtypeToLLVM(v->dtype()),
      size,
      nullptr,
      nullptr);
#endif
  // Insert the malloc call into the current block.
  irb_.SetInsertPoint(irb_.GetInsertBlock());
  llvm::Value* malloc = irb_.Insert(I);
  varToVal_[v->buffer_var()] = malloc;
}

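// Example of the size heuristic above: Allocate(buf, kFloat, {4, 16}) has a
// constant size of 4 * 4 * 16 = 256 bytes, which is below the 512-byte
// threshold, so it becomes a stack alloca. Allocate(buf, kFloat, {128, 128})
// (65536 bytes), or any allocation whose size is not a compile-time constant,
// goes through malloc and is paired with a later Free.
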
void LLVMCodeGenImpl::visit(const PlacementAllocatePtr& v) {
  auto buf_to_reuse = v->buf_to_reuse();
  auto buf = v->buf();

  if (bufsExtAlloc_.count(buf_to_reuse)) {
    bufsExtAllocReuse_.insert({buf_to_reuse, buf});
    return;
  }

  handleBufReuse(buf, buf_to_reuse);
}

void LLVMCodeGenImpl::visit(const FreePtr& v) {
  value_ = llvm::ConstantInt::get(IntTy_, 0);

  llvm::Value* ptr = bufsExtToFreeVal_.count(v->buffer_var())
      ? bufsExtToFreeVal_.at(v->buffer_var())
      : varToVal_.at(v->buffer_var());

  if (!llvm::isa<llvm::AllocaInst>(ptr)) {
#if LLVM_VERSION_MAJOR > 17
    irb_.Insert(irb_.CreateFree(ptr));
#else
    irb_.Insert(llvm::CallInst::CreateFree(ptr, irb_.GetInsertBlock()));
#endif
  }
}

void LLVMCodeGenImpl::visit(const FreeExtPtr& v) {
  value_ = llvm::ConstantInt::get(IntTy_, 0);
  const auto& bufs = v->bufs();
  const auto bufs_num = bufs.size();

#if LLVM_VERSION_MAJOR >= 15
  llvm::Value* ptrs = irb_.CreateAlloca(
      OpqPtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
#else
  llvm::Value* ptrs = irb_.CreateAlloca(
      Int8PtrTy_, llvm::ConstantInt::getSigned(IntTy_, bufs_num));
#endif

  for (const auto i : c10::irange(bufs_num)) {
    const auto& buf = bufs[i];
#if LLVM_VERSION_MAJOR >= 15
    llvm::Value* gep = irb_.CreateInBoundsGEP(
        OpqPtrTy_, ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
#else
    llvm::Value* gep = irb_.CreateInBoundsGEP(
        Int8PtrTy_, ptrs, llvm::ConstantInt::getSigned(IntTy_, i));
#endif

    auto ptr = bufsExtToFreeVal_[buf->base_handle()];
    irb_.CreateStore(ptr, gep);
  }

#if LLVM_VERSION_MAJOR >= 15
  FunctionCallee callee = module_->getOrInsertFunction(
      "nnc_aten_free",
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {
              LongTy_, // int64_t bufs_num
              OpqPtrTy_, // void** ptrs
          },
          false)); // is var_arg
#else
  FunctionCallee callee = module_->getOrInsertFunction(
      "nnc_aten_free",
      llvm::FunctionType::get(
          llvm::Type::getVoidTy(getContext()), // return type
          {
              LongTy_, // int64_t bufs_num
              Int8PtrTy_->getPointerTo(), // void** ptrs
          },
          false)); // is var_arg
#endif

  auto call_ty = callee.getFunctionType();
  auto call_fn = callee.getCallee();
  llvm::cast<llvm::Function>(call_fn)->addFnAttr(llvm::Attribute::NoUnwind);

  irb_.CreateCall(
      call_ty,
      call_fn,
      {llvm::ConstantInt::getSigned(LongTy_, bufs_num), ptrs});

  value_ = llvm::ConstantInt::get(IntTy_, 0);
}

void LLVMCodeGenImpl::visit(const LetPtr& v) {
  v->value()->accept(this);
  if (!varToVal_.count(v->var())) {
    varToVal_.emplace(v->var(), value_);
    scopeToVar_[scope_].push_back(v->var());
  } else {
    throw std::runtime_error("var should not exist before");
  }
}

void LLVMCodeGenImpl::visit(const CondPtr& v) {
  // Even if true_stmt and false_stmt are nullptr, we still evaluate the
  // condition, in case it is a function call with side effects.
  v->condition()->accept(this);

  if (!v->true_stmt() && !v->false_stmt()) {
    return;
  }
  assert(v->true_stmt());

  llvm::Value* condition = value_;
  llvm::Value* c = irb_.CreateICmpNE(
      condition, llvm::ConstantInt::get(condition->getType(), 0));
  llvm::BasicBlock* then_block =
      llvm::BasicBlock::Create(getContext(), "then", fn_);
  llvm::BasicBlock* else_block = nullptr;
  if (v->false_stmt()) {
    else_block = llvm::BasicBlock::Create(getContext(), "else", fn_);
  }
  llvm::BasicBlock* end_block =
      llvm::BasicBlock::Create(getContext(), "end", fn_);

  if (else_block) {
    irb_.CreateCondBr(c, then_block, else_block);
  } else {
    irb_.CreateCondBr(c, then_block, end_block);
  }

  irb_.SetInsertPoint(then_block);
  v->true_stmt()->accept(this);
  irb_.CreateBr(end_block);

  if (else_block) {
    irb_.SetInsertPoint(else_block);
    v->false_stmt()->accept(this);
    irb_.CreateBr(end_block);
  }

  irb_.SetInsertPoint(end_block);
}

// "New" PassManager needed to replace TM.adjustPassManager
#if LLVM_VERSION_MAJOR >= 15
void LLVMCodeGenImpl::optimize(llvm::Module& M) {
  // Add internal analysis passes from the target machine.
  auto& TM = jit_->getTargetMachine();

  // Create the analysis managers.
  llvm::LoopAnalysisManager LAM;
  llvm::FunctionAnalysisManager FAM;
  llvm::CGSCCAnalysisManager CGAM;
  llvm::ModuleAnalysisManager MAM;

  // Create the new pass manager builder.
  // Take a look at the PassBuilder constructor parameters for more
  // customization, e.g. specifying a TargetMachine or various debugging
  // options.
  llvm::PassBuilder PB(&TM);

#if LLVM_VERSION_MAJOR >= 18 && LLVM_VERSION_MAJOR < 19
  TM.registerPassBuilderCallbacks(PB, false);
#else
  TM.registerPassBuilderCallbacks(PB);
#endif

  // Register all the basic analyses with the managers.
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  llvm::ModulePassManager MPM =
      PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
  llvm::FunctionPassManager FPM = PB.buildFunctionSimplificationPipeline(
      llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None);

  FAM.registerPass([&] { return TM.getTargetIRAnalysis(); });

  FPM.addPass(llvm::LoopVectorizePass());
  FPM.addPass(llvm::SLPVectorizerPass());

  FPM.addPass(llvm::DCEPass());
  MPM.addPass(llvm::AlwaysInlinerPass());

  MPM.run(M, MAM);
  for (auto& FF : M) {
    if (!FF.empty()) {
      FPM.run(FF, FAM);
    }
  }
}
#else // "Old" PassManager
void LLVMCodeGenImpl::optimize(llvm::Module& M) {
  llvm::legacy::FunctionPassManager FPM(&M);
  llvm::legacy::PassManager PM;

  // Add internal analysis passes from the target machine.
  auto& TM = jit_->getTargetMachine();
  PM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));
  FPM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));

  llvm::PassManagerBuilder PMB;
  PMB.OptLevel = 3;
  PMB.LoopVectorize = true;
  PMB.SLPVectorize = true;
  TM.adjustPassManager(PMB);

  PMB.populateFunctionPassManager(FPM);
  PMB.populateModulePassManager(PM);
  FPM.doInitialization();
  PM.add(llvm::createDeadCodeEliminationPass());
  PM.add(llvm::createAlwaysInlinerLegacyPass());
  PM.run(M);
  for (auto& FF : M) {
    FPM.run(FF);
  }
  FPM.doFinalization();
}
#endif

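// For reference, the new-PM pipeline above roughly corresponds to running
// `opt -passes='default<O3>'` followed by an extra per-function
// LoopVectorize/SLPVectorizer/DCE sweep and a module-level AlwaysInliner
// (an approximate equivalence; exact behavior depends on the LLVM version).
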
RegisterCodeGen<LLVMCodeGen> llvm_codegen_reg("llvm_codegen");

} // namespace torch::jit::tensorexpr

#endif // TORCH_ENABLE_LLVM