#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <utility>

namespace torch::jit::tensorexpr {

template <typename T>
class PaddedBuffer;

class TORCH_API CodeGen {
 public:
  class BufferArg;
  class CallArg;

  template <typename... Ts>
  CodeGen(StmtPtr stmt, Ts... ts)
      : stmt_(std::move(stmt)), buffer_args_({BufferArg(ts)...}) {}

  CodeGen(
      StmtPtr stmt,
      std::vector<BufferArg> buffer_args,
      at::Device device = at::kCPU,
      std::string kernel_func_name = "func");

  virtual ~CodeGen() = default;

  StmtPtr stmt() const {
    return stmt_;
  }

  void set_stmt(StmtPtr s) {
    stmt_ = std::move(s);
  }

  void apply_mutator(IRMutator* mutator) {
    stmt_ = stmt_->accept_mutator(mutator);
  }

  void apply_visitor(IRVisitor* visitor) {
    stmt_->accept(visitor);
  }

  std::vector<BufferArg>& buffer_args() {
    return buffer_args_;
  }

  const std::vector<BufferArg>& buffer_args() const {
    return buffer_args_;
  }

  at::Device device() {
    return device_;
  }

  // This function returns the generated code as a string.
  virtual std::string getCodeText(
      const std::string& attr [[maybe_unused]] = "") {
    return "";
  }

  // TODO: Figure out how to unify these call interfaces.

  /// Call a function with a vector of CallArgs, which are tagged
  /// unions that properly type the arguments.
  virtual void call(const std::vector<CallArg>& args) = 0;

  /// Call a function faster than a regular `call` by assuming that
  /// the generated kernel already knows the type of the arguments, so
  /// they can be type-punned with `void*`s.
  virtual void call_raw(const std::vector<void*>& args) = 0;

  /// Call a function even faster than a regular call, by assuming
  /// that the number of thread blocks can be derived from `numel` via
  /// a simple division, rather than evaluating an expression.
  virtual void call_with_numel(void** args, int64_t numel);
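
  /// A sketch of how the three entry points relate (hypothetical; `cg` is a
  /// pointer to a concrete CodeGen subclass, and `a` and `b` are the caller's
  /// float buffers of length `n`):
  ///
  ///   cg->call({CallArg(a), CallArg(b)});  // typed, tagged arguments
  ///   std::vector<void*> raw = {a, b};
  ///   cg->call_raw(raw);                   // type-punned void* arguments
  ///   cg->call_with_numel(raw.data(), n);  // launch size derived from numel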

  virtual at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) {
    return at::empty_strided(
        size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  }

  const std::string& kernel_func_name() const {
    return kernel_func_name_;
  }

  void allocIntermediateBufs();

 protected:
  static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> buffer_args_;
  at::Device device_ = at::kCPU;
  std::string kernel_func_name_ = "func";
};

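// Rewrites ExternalCall nodes to counterparts that can reuse existing memory
// (per the static function-name map below) instead of always allocating their
// outputs. Buffers that are kernel arguments are tracked in bufferArgs_.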
class TORCH_API ExtCallMemoryReuse : public IRMutator {
  static std::unordered_map<std::string, std::string> makeExtCallFuncNameMap();
  static const std::unordered_map<std::string, std::string> extCallFuncNameMap_;

 public:
  explicit ExtCallMemoryReuse(
      const std::vector<CodeGen::BufferArg>& bufferArgs);
  ~ExtCallMemoryReuse() override = default;
  StmtPtr mutate(const ExternalCallPtr& v) override;

 private:
  std::unordered_set<BufPtr> bufferArgs_;
};

class CodeGen::BufferArg {
 public:
  BufferArg(const Tensor& tensor) : buf_(tensor.buf()) {}
  BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
  BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
  BufferArg(BufPtr buf) : buf_(std::move(buf)) {}

  VarPtr var() const {
    return isVar_ ? var_ : buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  bool isVar() const {
    return isVar_;
  }

  Dtype dtype() const {
    return isVar_ ? var_->dtype() : buf_->dtype();
  }

 private:
  VarPtr var_ = nullptr;
  BufPtr buf_ = nullptr;
  bool isVar_ = false;
};
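
// A BufferArg wraps either a buffer (Tensor, BufHandle, BufPtr) or a scalar
// variable (VarHandle), so both kinds of kernel parameter travel through one
// argument list. A minimal sketch, assuming a kernel over a buffer `A` and a
// scalar bound `n` (placeholder names, not part of this header):
//
//   BufHandle A("A", {64}, kFloat);
//   VarHandle n("n", kInt);
//   std::vector<CodeGen::BufferArg> args = {A, n};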

class CodeGen::CallArg {
 public:
  template <typename T>
  CallArg(const PaddedBuffer<T>& buffer);

  template <typename T>
  CallArg(const std::vector<T>& buffer)
      : data_(const_cast<T*>(buffer.data())) {}

  CallArg(void* ptr) : data_(ptr) {}

#define ARG_TYPE_CTOR(Type, Name)      \
  CallArg(Type v) {                    \
    memcpy(buffer_, &v, sizeof(Type)); \
    data_ = (void*)buffer_;            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
#undef ARG_TYPE_CTOR

  void* data() const {
    return data_;
  }

  CallArg(const CallArg& rhs) {
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
  }

  CallArg& operator=(const CallArg& rhs) {
    if (this == &rhs) {
      return *this;
    }
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
    return *this;
  }

#define ARG_PTR_DEFINE(Type, Name)                  \
  Type* Name##Ptr() const {                         \
    TORCH_INTERNAL_ASSERT(data_ == (void*)buffer_); \
    return (Type*)data_;                            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
#undef ARG_PTR_DEFINE

 private:
  void* data_;
  // A scalar value used to be stored directly in data_ (treating &data_ as
  // void** storage), but a pointer is only 32 bits wide on a 32-bit platform,
  // so it cannot hold a scalar wider than 32 bits, such as a double or an
  // int64_t. Hence we keep an 8-byte buffer dedicated to the scalar value,
  // whatever its width, and point data_ at it.
  char buffer_[8] = {0}; // 64 bits
};
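
// A sketch of the inline-scalar round trip (hypothetical values; DoublePtr()
// is one of the accessors generated by ARG_PTR_DEFINE above):
//
//   CallArg arg(3.14);            // scalar is memcpy'd into buffer_
//   double* p = arg.DoublePtr();  // asserts that data_ points at buffer_
//   // *p == 3.14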

class RegisterCodeGenList {
 public:
  TORCH_API static RegisterCodeGenList& GetInstance();

  using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>&,
      at::Device device,
      const std::string& kernel_func_name)>;

  TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
  RegisterCodeGenList(const RegisterCodeGenList&) = delete;
  RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;

 private:
  template <class CodeGenType>
  friend class RegisterCodeGen;
  RegisterCodeGenList() = default;
  TORCH_API void AddStmtFactoryMethod(
      const std::string& name,
      const StmtFactoryMethod& stmt_factory_method);

  std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};

template <class CodeGenType>
class RegisterCodeGen {
 public:
  explicit RegisterCodeGen(const std::string& name) {
    RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
    codegen_list.AddStmtFactoryMethod(
        name,
        [](StmtPtr stmt,
           const std::vector<CodeGen::BufferArg>& params,
           at::Device device,
           const std::string& kernel_func_name) {
          std::unique_ptr<CodeGen> method(
              new CodeGenType(stmt, params, device, kernel_func_name));
          return method;
        });
  }
};
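
// Backends register themselves with the singleton list by defining a
// file-scope RegisterCodeGen instance. A minimal sketch, assuming a
// hypothetical subclass MyCodeGen whose constructor matches the factory
// lambda above:
//
//   static RegisterCodeGen<MyCodeGen> my_codegen_reg("my_codegen");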

TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device = at::kCPU,
    const std::string& kernel_func_name = "func");
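
// CreateCodeGen looks `name` up in the registry and invokes the stored
// factory. A hedged sketch (assumes `stmt` and `args` were built elsewhere,
// and that a backend was registered under "my_codegen" as above):
//
//   std::unique_ptr<CodeGen> cg =
//       CreateCodeGen("my_codegen", stmt, args, at::kCPU, "my_kernel");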

class TORCH_API GenericIntrinsicsExpander : public IRMutator {
 protected:
  ExprPtr mutate(const IntrinsicsPtr& v) override;
};

} // namespace torch::jit::tensorexpr