#pragma once

#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <optional>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

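// Opaque implementation types: LLVMCodeGenImpl owns the state used while
// generating code (needed by e.g. `getCodeText`), while LLVMCodeGenCallee
// holds the compiled kernel resolved by `getKernelAddress`.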
class LLVMCodeGenImpl;
class LLVMCodeGenCallee;

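// CodeGen backend that compiles a TensorExpr statement into a native kernel
// via LLVM. The compiled kernel can be invoked through `call`, `call_raw`,
// `call_with_numel`, or the `value<T>` helpers below.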
class TORCH_API LLVMCodeGen : public CodeGen {
 public:
  explicit LLVMCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func",
      Dtype dtype = kInt,
      std::optional<std::string> triple = std::nullopt,
      std::optional<std::string> cpu = std::nullopt,
      std::optional<std::string> attrs = std::nullopt);
  explicit LLVMCodeGen(StmtPtr stmt);

  LLVMCodeGen() = delete;
  ~LLVMCodeGen() override;

  // Cleans up all the memory used during the LLVM code generation pass,
  // except for the generated kernel. After calling this method, users should
  // not call methods like `getCodeText` that require the LLVMCodeGenImpl
  // data. However, they can still invoke the generated kernel via `call` and
  // `call_raw`.
  void cleanup_memory();

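  // Entry points for invoking the compiled kernel: with typed CallArgs, with
  // raw argument pointers, or with a raw argument array plus an element count.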
  TORCH_API void call(const std::vector<CallArg>& args) override;
  TORCH_API void call_raw(const std::vector<void*>& args) override;
  TORCH_API void call_with_numel(void** args, int64_t numel) override;

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

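  // The `value<T>` overloads below reinterpret the compiled kernel's entry
  // point as a `T (*)(void**)` function pointer and invoke it, passing either
  // no arguments (nullptr), a vector of raw pointers, or a raw pointer array.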
  template <typename T>
  T value() {
    return value<T>(nullptr);
  }

  template <typename T>
  T value(std::vector<void*>& args) {
    return value<T>(args.data());
  }

  template <typename T>
  T value(void** args) {
    T (*fp)(void**) = (T(*)(void**))getKernelAddress(callee_.get());
    T rv = fp(args);
    return rv;
  }

  std::string getCodeText(const std::string& attr = "") override;

 private:
  void* getKernelAddress(LLVMCodeGenCallee* callee);

  std::unique_ptr<LLVMCodeGenCallee> callee_;
  std::unique_ptr<LLVMCodeGenImpl> impl_;
};

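// Fluent builder for LLVMCodeGen; options left unset fall back to the
// defaults listed in the private members below. Illustrative sketch (`stmt`,
// `args`, and `call_args` are placeholders, not names from this codebase):
//
//   auto cg = LLVMCodeGenBuilder(stmt, args)
//                 .device(at::kCPU)
//                 .kernelFuncName("my_kernel")
//                 .build();
//   cg->call(call_args);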
struct TORCH_API LLVMCodeGenBuilder {
  using BufferArg = CodeGen::BufferArg;

  LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
      : stmt_(stmt), args_(std::move(args)) {}

  LLVMCodeGenBuilder& device(at::Device device) {
    device_ = device;
    return *this;
  }

  LLVMCodeGenBuilder& kernelFuncName(std::string name) {
    kernelFuncName_ = std::move(name);
    return *this;
  }

  LLVMCodeGenBuilder& dtype(Dtype d) {
    dtype_ = d;
    return *this;
  }

  LLVMCodeGenBuilder& triple(std::string triple) {
    triple_ = std::move(triple);
    return *this;
  }

  LLVMCodeGenBuilder& cpu(std::string cpu) {
    cpu_ = std::move(cpu);
    return *this;
  }

  LLVMCodeGenBuilder& attrs(std::string attrs) {
    attrs_ = std::move(attrs);
    return *this;
  }

  std::unique_ptr<LLVMCodeGen> build() {
    return std::make_unique<LLVMCodeGen>(
        stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
  }

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> args_;
  at::Device device_ = at::kCPU;
  std::string kernelFuncName_ = "func";
  Dtype dtype_ = kInt;
  std::optional<std::string> triple_ = std::nullopt;
  std::optional<std::string> cpu_ = std::nullopt;
  std::optional<std::string> attrs_ = std::nullopt;
};

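// Accessors returning mutable references to the target configuration used by
// the LLVM backend (target triple, CPU, and attribute string) and to a flag
// selecting the ahead-of-time (AOT) compilation workflow.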
TORCH_API std::optional<std::string>& LLVMTargetTriple();
TORCH_API std::optional<std::string>& LLVMTargetCPU();
TORCH_API std::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();

} // namespace tensorexpr
} // namespace jit
} // namespace torch

#endif // TORCH_ENABLE_LLVM