#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <utility>

namespace torch::jit::tensorexpr {

template <typename T>
class PaddedBuffer;

class TORCH_API CodeGen {
 public:
  class BufferArg;
  class CallArg;

  template <typename... Ts>
  CodeGen(StmtPtr stmt, Ts... ts)
      : stmt_(std::move(stmt)), buffer_args_({BufferArg(ts)...}) {}

  CodeGen(
      StmtPtr stmt,
      std::vector<BufferArg> buffer_args,
      at::Device device = at::kCPU,
      std::string kernel_func_name = "func");

  virtual ~CodeGen() = default;

  StmtPtr stmt() const {
    return stmt_;
  }

  void set_stmt(StmtPtr s) {
    stmt_ = std::move(s);
  }

  void apply_mutator(IRMutator* mutator) {
    stmt_ = stmt_->accept_mutator(mutator);
  }

  void apply_visitor(IRVisitor* visitor) {
    stmt_->accept(visitor);
  }

  std::vector<BufferArg>& buffer_args() {
    return buffer_args_;
  }

  const std::vector<BufferArg>& buffer_args() const {
    return buffer_args_;
  }

  at::Device device() {
    return device_;
  }

  // This function returns the generated code as a string.
  virtual std::string getCodeText(
      const std::string& attr [[maybe_unused]] = "") {
    return "";
  }

  // TODO: Figure out how to unify these call interfaces.

  /// Call a function with a vector of CallArgs, which are tagged
  /// unions that properly type the arguments.
  virtual void call(const std::vector<CallArg>& args) = 0;

  /// Call a function faster than a regular `call` by assuming that
  /// the generated kernel already knows the type of the arguments, so
  /// they can be type-punned with `void*`s.
  virtual void call_raw(const std::vector<void*>& args) = 0;

  /// Call a function even faster than a regular call, by assuming
  /// that the number of thread blocks can be derived from `numel`
  /// via a simple division, rather than evaluating an expression.
  virtual void call_with_numel(void** args, int64_t numel);
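
  // Usage sketch (illustrative only, not part of this API): invoking a
  // generated kernel through the typed and raw call paths above. `cg` is
  // assumed to point at a concrete CodeGen subclass, and `a`, `b` are
  // hypothetical contiguous tensors matching the kernel's BufferArgs in
  // order and dtype.
  //
  //   std::vector<CodeGen::CallArg> typed = {a.data_ptr(), b.data_ptr()};
  //   cg->call(typed); // tagged-union arguments, typed per BufferArg
  //
  //   std::vector<void*> raw = {a.data_ptr(), b.data_ptr()};
  //   cg->call_raw(raw); // type-punned arguments, lower per-call overhead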

  virtual at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) {
    return at::empty_strided(
        size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  }

  const std::string& kernel_func_name() const {
    return kernel_func_name_;
  }

  void allocIntermediateBufs();

 protected:
  static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> buffer_args_;
  at::Device device_ = at::kCPU;
  std::string kernel_func_name_ = "func";
};

class TORCH_API ExtCallMemoryReuse : public IRMutator {
  static std::unordered_map<std::string, std::string> makeExtCallFuncNameMap();
  static const std::unordered_map<std::string, std::string> extCallFuncNameMap_;

 public:
  explicit ExtCallMemoryReuse(
      const std::vector<CodeGen::BufferArg>& bufferArgs);
  ~ExtCallMemoryReuse() override = default;
  StmtPtr mutate(const ExternalCallPtr& v) override;

 private:
  std::unordered_set<BufPtr> bufferArgs_;
};
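
// Usage sketch (hypothetical names): describing a kernel's arguments as
// BufferArgs and constructing a backend by name. `stmt` is assumed to be a
// lowered StmtPtr, `A` and `B` BufHandles used by that statement, and `N` a
// VarHandle bound to a dynamic dimension; the backend name is assumed to be
// registered via RegisterCodeGen (see below).
//
//   std::vector<CodeGen::BufferArg> args = {A, B, N};
//   std::unique_ptr<CodeGen> cg =
//       CreateCodeGen("llvm_codegen", stmt, args, at::kCPU, "my_kernel");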

class CodeGen::BufferArg {
 public:
  BufferArg(const Tensor& tensor) : buf_(tensor.buf()) {}
  BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
  BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
  BufferArg(BufPtr buf) : buf_(std::move(buf)) {}

  VarPtr var() const {
    return isVar_ ? var_ : buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  bool isVar() const {
    return isVar_;
  }

  Dtype dtype() const {
    return isVar_ ? var_->dtype() : buf_->dtype();
  }

 private:
  VarPtr var_ = nullptr;
  BufPtr buf_ = nullptr;
  bool isVar_ = false;
};

class CodeGen::CallArg {
 public:
  template <typename T>
  CallArg(const PaddedBuffer<T>& buffer);

  template <typename T>
  CallArg(const std::vector<T>& buffer)
      : data_(const_cast<T*>(buffer.data())) {}

  CallArg(void* ptr) : data_(ptr) {}

#define ARG_TYPE_CTOR(Type, Name)      \
  CallArg(Type v) {                    \
    memcpy(buffer_, &v, sizeof(Type)); \
    data_ = (void*)buffer_;            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
#undef ARG_TYPE_CTOR

  void* data() const {
    return data_;
  }

  CallArg(const CallArg& rhs) {
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
  }

  CallArg& operator=(const CallArg& rhs) {
    if (this == &rhs) {
      return *this;
    }
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
    return *this;
  }

#define ARG_PTR_DEFINE(Type, Name)                  \
  Type* Name##Ptr() const {                         \
    TORCH_INTERNAL_ASSERT(data_ == (void*)buffer_); \
    return (Type*)data_;                            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
#undef ARG_PTR_DEFINE

 private:
  void* data_;
  // A scalar value used to be stored directly in the pointer slot (via
  // &data_), but a pointer is only 32 bits wide on a 32-bit platform and
  // cannot hold scalars wider than that, such as double and long. Hence we
  // keep a dedicated 8-byte buffer for the scalar value regardless of
  // whether its bit width is smaller or larger than 32 bits.
  char buffer_[8] = {0}; // 64 bits
};
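
// Usage sketch (illustrative only): scalar CallArgs copy the value into the
// internal 8-byte buffer, while pointer CallArgs merely alias the caller's
// memory; the typed accessors generated above (e.g. DoublePtr()) are only
// valid in the scalar case.
//
//   CodeGen::CallArg s(3.5);   // scalar: the bytes live in buffer_
//   double* p = s.DoublePtr(); // points into the CallArg itself
//
//   float data[4] = {0};
//   CodeGen::CallArg b(static_cast<void*>(data)); // aliases `data`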

class RegisterCodeGenList {
 public:
  TORCH_API static RegisterCodeGenList& GetInstance();

  using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>&,
      at::Device device,
      const std::string& kernel_func_name)>;

  TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
  RegisterCodeGenList(const RegisterCodeGenList&) = delete;
  RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;

 private:
  template <class CodeGenType>
  friend class RegisterCodeGen;
  RegisterCodeGenList() = default;
  TORCH_API void AddStmtFactoryMethod(
      const std::string& name,
      const StmtFactoryMethod& stmt_factory_method);

  std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};

template <class CodeGenType>
class RegisterCodeGen {
 public:
  explicit RegisterCodeGen(const std::string& name) {
    RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
    codegen_list.AddStmtFactoryMethod(
        name,
        [](StmtPtr stmt,
           const std::vector<CodeGen::BufferArg>& params,
           at::Device device,
           const std::string& kernel_func_name) {
          std::unique_ptr<CodeGen> method(
              new CodeGenType(stmt, params, device, kernel_func_name));
          return method;
        });
  }
};

TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device = at::kCPU,
    const std::string& kernel_func_name = "func");

class TORCH_API GenericIntrinsicsExpander : public IRMutator {
 protected:
  ExprPtr mutate(const IntrinsicsPtr& v) override;
};

} // namespace torch::jit::tensorexpr