#pragma once

#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

class LLVMCodeGenImpl;
class LLVMCodeGenCallee;

// CodeGen subclass whose state is split between two opaque helper objects:
// `impl_` (LLVMCodeGenImpl) and `callee_` (LLVMCodeGenCallee). Both types are
// only forward-declared in this header.
class TORCH_API LLVMCodeGen : public CodeGen {
 public:
  // Constructs a code generator for `stmt` with the buffer arguments `args`.
  //
  // Defaults: device = at::kCPU, kernel_func_name = "func", dtype = kInt, and
  // no explicit triple / cpu / attrs.
  // NOTE(review): `triple`, `cpu` and `attrs` presumably select the LLVM
  // target triple, CPU name and feature attributes -- confirm against the
  // implementation file.
  explicit LLVMCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func",
      Dtype dtype = kInt,
      std::optional<std::string> triple = std::nullopt,
      std::optional<std::string> cpu = std::nullopt,
      std::optional<std::string> attrs = std::nullopt);

  // Convenience overload taking only the statement; the values used for the
  // remaining parameters are chosen in the implementation file.
  explicit LLVMCodeGen(StmtPtr stmt);

  LLVMCodeGen() = delete;
  ~LLVMCodeGen() override;

  // Cleans up all the memory used during LLVM code generation pass except
  // the generated kernel. After calling this method, users should not call
  // methods like `getCodeText` that require the LLVMCodeGenImpl data. However,
  // users can continue to call this kernel using `call` and `call_raw`.
  void cleanup_memory();

  // CodeGen entry points for invoking the generated kernel.
  TORCH_API void call(const std::vector<CallArg>& args) override;
  TORCH_API void call_raw(const std::vector<void*>& args) override;
  TORCH_API void call_with_numel(void** args, int64_t numel) override;

  // Override of CodeGen::empty_strided; all parameters are forwarded as
  // declared by the base class interface.
  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

  // Invokes the kernel with a null argument array and returns its result as T.
  template <typename T>
  T value() {
    return value<T>(nullptr);
  }

  // Invokes the kernel with the contents of `args` as its argument array and
  // returns its result as T.
  template <typename T>
  T value(std::vector<void*>& args) {
    return value<T>(args.data());
  }

  // Invokes the kernel with the raw argument array `args`.
  //
  // The address obtained from getKernelAddress(callee_.get()) is treated as a
  // function of type `T(void**)`; the caller is responsible for choosing a T
  // that matches the generated kernel's actual return type.
  template <typename T>
  T value(void** args) {
    using KernelFn = T (*)(void**);
    // An explicit reinterpret_cast makes the object-pointer -> function-pointer
    // conversion visible and greppable (it was previously a C-style cast).
    auto kernel = reinterpret_cast<KernelFn>(getKernelAddress(callee_.get()));
    // Returning the call expression directly also permits T = void.
    return kernel(args);
  }

  // Returns a textual form of the generated code. The meaning of `attr` is
  // defined by the implementation file.
  std::string getCodeText(const std::string& attr = "") override;

 private:
  // Returns the entry address of the kernel held by `callee`.
  void* getKernelAddress(LLVMCodeGenCallee* callee);

  std::unique_ptr<LLVMCodeGenCallee> callee_;
  std::unique_ptr<LLVMCodeGenImpl> impl_;
};

// Fluent builder for LLVMCodeGen. Each setter stores its argument and returns
// `*this` so calls can be chained; `build()` forwards the stored values to the
// eight-argument LLVMCodeGen constructor. Unset options keep the defaults
// declared on the data members below, which match that constructor's defaults.
struct TORCH_API LLVMCodeGenBuilder {
  using BufferArg = CodeGen::BufferArg;

  // Both parameters are taken by value and moved into the builder.
  LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
      : stmt_(std::move(stmt)), args_(std::move(args)) {}

  LLVMCodeGenBuilder& device(at::Device device) {
    device_ = device;
    return *this;
  }

  LLVMCodeGenBuilder& kernelFuncName(std::string name) {
    kernelFuncName_ = std::move(name);
    return *this;
  }

  LLVMCodeGenBuilder& dtype(Dtype d) {
    dtype_ = d;
    return *this;
  }

  LLVMCodeGenBuilder& triple(std::string triple) {
    triple_ = std::move(triple);
    return *this;
  }

  LLVMCodeGenBuilder& cpu(std::string cpu) {
    cpu_ = std::move(cpu);
    return *this;
  }

  LLVMCodeGenBuilder& attrs(std::string attrs) {
    attrs_ = std::move(attrs);
    return *this;
  }

  // Creates a new LLVMCodeGen from the current settings. The stored values are
  // passed as lvalues (not moved from), so the builder's members are left
  // intact after the call.
  std::unique_ptr<LLVMCodeGen> build() {
    return std::make_unique<LLVMCodeGen>(
        stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
  }

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> args_;
  at::Device device_ = at::kCPU;
  std::string kernelFuncName_ = "func";
  Dtype dtype_ = kInt;
  std::optional<std::string> triple_ = std::nullopt;
  std::optional<std::string> cpu_ = std::nullopt;
  std::optional<std::string> attrs_ = std::nullopt;
};

// Accessors returning mutable references to settings defined elsewhere.
// NOTE(review): presumably process-wide defaults for the target triple, CPU,
// attributes and AOT workflow flag -- confirm against the implementation file.
TORCH_API std::optional<std::string>& LLVMTargetTriple();
TORCH_API std::optional<std::string>& LLVMTargetCPU();
TORCH_API std::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();

} // namespace tensorexpr
} // namespace jit
} // namespace torch

#endif // TORCH_ENABLE_LLVM