xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/cpp_codegen.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <algorithm>
2 #include <type_traits>
3 #include <utility>
4 #include <vector>
5 
6 #include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
7 #include <torch/csrc/jit/tensorexpr/cpp_intrinsics.h>
8 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
9 #include <torch/csrc/jit/tensorexpr/types.h>
10 
11 namespace torch::jit::tensorexpr {
12 
13 // Rewrites the variables' name according to valid C++ naming convention.
14 // E.g. in Graph IR, variable name may contain '.', in C++, they are replaced
15 // with '_'.
16 class CppVarNameRewriter : public IRVisitor {
17  public:
visit(const VarPtr & v)18   void visit(const VarPtr& v) override {
19     constexpr char kDot = '.';
20     constexpr char kUnderscore = '_';
21     if (v->name_hint().find(kDot) == std::string::npos) {
22       return;
23     }
24     std::string name = v->name_hint();
25     std::replace(name.begin(), name.end(), kDot, kUnderscore);
26     v->set_name_hint(std::move(name));
27   }
28 
visit(const BufPtr & v)29   void visit(const BufPtr& v) override {
30     v->base_handle()->accept(this);
31   }
32 };
33 
// Builds the C++ forward declaration of an external function.
// Every registered function is declared with the same C-style signature:
// buffer count, buffer data pointers, ranks, flattened dims, dtypes, and a
// count plus array of extra int64 arguments.
static std::string declareExternalFunction(const std::string& func_name) {
  std::string decl = "void ";
  decl += func_name;
  decl +=
      "(int64_t bufs_num, void** buf_data, int64_t* buf_ranks, "
      "int64_t* buf_dims, int8_t* buf_dtypes, int64_t args_num, "
      "int64_t* extra_args);";
  return decl;
}
45 
CppPrinter(std::ostream * os)46 CppPrinter::CppPrinter(std::ostream* os) : IRPrinter(*os), lane_(0) {}
47 
// Out-of-line defaulted destructor; the printer has no cleanup of its own.
48 CppPrinter::~CppPrinter() = default;
49 
printPrologue()50 void CppPrinter::printPrologue() {
51   os() << "#include <cassert>" << '\n';
52   os() << "#include <cmath>" << '\n';
53   os() << "#include <algorithm>" << '\n';
54   os() << "#include <type_traits>" << '\n';
55   os() << '\n';
56 
57   os() << "#define POS_INFINITY INFINITY" << '\n';
58   os() << "#define NEG_INFINITY -INFINITY" << '\n';
59   os() << '\n';
60 
61   os() << cpp_intrinsics_definition << '\n';
62   os() << '\n';
63 
64   os() << "namespace torch {" << '\n';
65   os() << "namespace jit {" << '\n';
66   os() << "namespace tensorexpr {" << '\n';
67   for (auto const& it : getNNCFunctionRegistry()) {
68     os() << declareExternalFunction(it.first) << '\n';
69   }
70   os() << "} // namespace tensorexpr" << '\n';
71   os() << "} // namespace jit" << '\n';
72   os() << "} // namespace torch" << '\n';
73   os() << '\n';
74 
75   os() << "using namespace torch::jit::tensorexpr;" << '\n';
76   os() << '\n';
77 }
78 
79 template <typename T>
visit_mod(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)80 inline std::enable_if_t<!std::is_floating_point_v<T>, void> visit_mod(
81     std::ostream& os,
82     const ExprPtr& lhs,
83     const ExprPtr& rhs) {
84   os << *lhs << " % " << *rhs;
85 }
86 
87 template <typename T>
visit_mod(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)88 inline std::enable_if_t<std::is_floating_point_v<T>, void> visit_mod(
89     std::ostream& os,
90     const ExprPtr& lhs,
91     const ExprPtr& rhs) {
92   os << "std::fmod(" << *lhs << ", " << *rhs << ")";
93 }
94 
95 template <typename T>
96 inline std::
97     enable_if_t<std::is_floating_point_v<T> || std::is_integral_v<T>, void>
visit_max(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)98     visit_max(std::ostream& os, const ExprPtr& lhs, const ExprPtr& rhs) {
99   os << "std::max(" << *lhs << ", " << *rhs << ")";
100 }
101 
102 template <typename T>
103 inline std::
104     enable_if_t<!std::is_floating_point_v<T> && !std::is_integral_v<T>, void>
visit_max(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)105     visit_max(std::ostream& os, const ExprPtr& lhs, const ExprPtr& rhs) {
106   os << "(" << *lhs << " < " << *rhs << ") ? " << *rhs << " : " << *lhs;
107 }
108 
109 template <typename T>
110 inline std::
111     enable_if_t<std::is_floating_point_v<T> || std::is_integral_v<T>, void>
visit_min(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)112     visit_min(std::ostream& os, const ExprPtr& lhs, const ExprPtr& rhs) {
113   os << "std::min(" << *lhs << ", " << *rhs << ")";
114 }
115 
116 template <typename T>
117 inline std::
118     enable_if_t<!std::is_floating_point_v<T> && !std::is_integral_v<T>, void>
visit_min(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs)119     visit_min(std::ostream& os, const ExprPtr& lhs, const ExprPtr& rhs) {
120   os << *lhs << " < " << *rhs << " ? " << *lhs << " : " << *rhs;
121 }
122 
123 template <typename T>
visit_binary_op(std::ostream & os,const ExprPtr & lhs,const ExprPtr & rhs,IRNodeType op_type)124 void visit_binary_op(
125     std::ostream& os,
126     const ExprPtr& lhs,
127     const ExprPtr& rhs,
128     IRNodeType op_type) {
129   switch (op_type) {
130     case IRNodeType::kMod:
131       visit_mod<T>(os, lhs, rhs);
132       break;
133     case IRNodeType::kMax:
134       visit_max<T>(os, lhs, rhs);
135       break;
136     case IRNodeType::kMin:
137       visit_min<T>(os, lhs, rhs);
138       break;
139     default:
140       throw std::runtime_error("invalid op type");
141   }
142 }
143 
144 template <typename Op>
dispatch_binary_op(std::ostream & os,const BinaryOpNode<Op> * v)145 void dispatch_binary_op(std::ostream& os, const BinaryOpNode<Op>* v) {
146   switch (v->lhs()->dtype().scalar_type()) {
147 #define TYPE_CASE(Type, Name)                                      \
148   case ScalarType::Name:                                           \
149     visit_binary_op<Type>(os, v->lhs(), v->rhs(), v->expr_type()); \
150     break;
151     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
152 #undef TYPE_CASE
153     default:
154       throw unsupported_dtype();
155   }
156 }
157 
visit(const RampPtr & v)158 void CppPrinter::visit(const RampPtr& v) {
159   visit(alloc<Add>(v->base(), alloc<Mul>(alloc<IntImm>(lane_), v->stride())));
160 }
161 
visit(const BroadcastPtr & v)162 void CppPrinter::visit(const BroadcastPtr& v) {
163   v->value()->accept(this);
164 }
165 
visit(const ModPtr & v)166 void CppPrinter::visit(const ModPtr& v) {
167   dispatch_binary_op(os(), v.get());
168 }
169 
visit(const MaxPtr & v)170 void CppPrinter::visit(const MaxPtr& v) {
171   dispatch_binary_op(os(), v.get());
172 }
173 
visit(const MinPtr & v)174 void CppPrinter::visit(const MinPtr& v) {
175   dispatch_binary_op(os(), v.get());
176 }
177 
visit(const CompareSelectPtr & v)178 void CppPrinter::visit(const CompareSelectPtr& v) {
179   os() << "((" << *v->lhs() << " "
180        << IRPrinter::to_string(v->compare_select_op()) << " " << *v->rhs()
181        << ") ? " << *v->ret_val1() << " : " << *v->ret_val2() << ")";
182 }
183 
visit(const IfThenElsePtr & v)184 void CppPrinter::visit(const IfThenElsePtr& v) {
185   os() << "((" << *v->condition() << ") ? " << *v->true_value() << " : "
186        << *v->false_value() << ")";
187 }
188 
visit(const AllocatePtr & v)189 void CppPrinter::visit(const AllocatePtr& v) {
190   size_t size = v->dtype().byte_size();
191   for (const auto& dim : v->dims()) {
192     IntImmPtr d = to<IntImm>(dim);
193     if (d) {
194       size *= d->value();
195     } else {
196       throw std::runtime_error("Only IntImm dimensions are supported for now");
197     }
198   }
199 
200   emitIndent();
201   os() << v->dtype().ToCppString() << "* " << (*v->buffer_var())
202        << " = static_cast<" << v->dtype().ToCppString() << "*>(malloc(" << size
203        << "));" << '\n';
204 }
205 
visit(const FreePtr & v)206 void CppPrinter::visit(const FreePtr& v) {
207   emitIndent();
208   os() << "free(" << *v->buffer_var() << ");" << '\n';
209 }
210 
visit(const LoadPtr & v)211 void CppPrinter::visit(const LoadPtr& v) {
212   auto flat_idx =
213       flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
214   os() << *v->base_handle() << "[" << *flat_idx << "]";
215 }
216 
visit(const StorePtr & v)217 void CppPrinter::visit(const StorePtr& v) {
218   auto flat_idx =
219       flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
220   const int lanes = v->value()->dtype().lanes();
221   for (int lane = 0; lane < lanes; lane++) {
222     lane_ = lane;
223     emitIndent();
224     os() << *v->base_handle() << "[" << *flat_idx << "] = " << *v->value()
225          << ";" << '\n';
226   }
227 }
228 
visit(const CastPtr & v)229 void CppPrinter::visit(const CastPtr& v) {
230   os() << "static_cast<" << v->dtype().ToCppString() << ">(" << *v->src_value()
231        << ")";
232 }
233 
visit(const BitCastPtr & v)234 void CppPrinter::visit(const BitCastPtr& v) {
235   os() << "std::bitcast<" << v->src_value()->dtype().ToCppString() << ", "
236        << v->dtype().ToCppString() << ">(" << *v->src_value() << ")";
237 }
238 
visit(const IntrinsicsPtr & v)239 void CppPrinter::visit(const IntrinsicsPtr& v) {
240   if (v->op_type() == kRand || v->op_type() == kSigmoid) {
241     throw std::runtime_error("kRand and kSigmoid are not supported");
242   }
243 
244   os() << "std::" << v->func_name() << "(";
245   for (size_t i = 0; i < v->nparams(); i++) {
246     if (i > 0) {
247       os() << ", ";
248     }
249     os() << *v->param(i);
250   }
251   os() << ")";
252 }
253 
visit(const ExternalCallPtr & v)254 void CppPrinter::visit(const ExternalCallPtr& v) {
255   // The generated code needs to link against functions defined
256   // in external_functions.cpp.
257 
258   auto& func_registry = getNNCFunctionRegistry();
259   if (!func_registry.count(v->func_name())) {
260     throw unimplemented_lowering(v);
261   }
262 
263   std::vector<BufPtr> bufs(v->buf_args());
264   bufs.insert(bufs.begin(), v->buf());
265   auto for_buf = [&](const std::function<void(const BufPtr)>& print_buf) {
266     for (size_t i = 0; i < bufs.size(); i++) {
267       if (i > 0) {
268         os() << ", ";
269       }
270       print_buf(bufs[i]);
271     }
272   };
273 
274   emitIndent();
275   os() << "{" << '\n';
276   indent_++;
277 
278   emitIndent();
279   os() << "void* buf_ptrs[]{";
280   for_buf([&](const BufPtr& b) { os() << *b->base_handle(); });
281   os() << "};" << '\n';
282 
283   emitIndent();
284   os() << "int64_t buf_ranks[]{";
285   for_buf([&](const BufPtr& b) { os() << b->ndim(); });
286   os() << "};" << '\n';
287 
288   emitIndent();
289   os() << "int64_t buf_dims[]{";
290   for_buf([&](const BufPtr& buf) {
291     for (size_t i = 0; i < buf->ndim(); i++) {
292       if (i > 0) {
293         os() << ", ";
294       }
295       os() << *buf->dim(i);
296     }
297   });
298   os() << "};" << '\n';
299 
300   emitIndent();
301   os() << "int8_t buf_dtypes[]{";
302   for_buf([&](const BufPtr& buf) {
303     os() << static_cast<int>(buf->dtype().scalar_type());
304   });
305   os() << "};" << '\n';
306 
307   emitIndent();
308   os() << "int64_t extra_args[]{";
309   for (size_t i = 0; i < v->args().size(); i++) {
310     if (i > 0) {
311       os() << ", ";
312     }
313     os() << *v->args()[i];
314   }
315   os() << "};" << '\n';
316 
317   emitIndent();
318   os() << v->func_name() << "(" << '\n';
319   emitIndent();
320   os() << "    " << bufs.size() << "," << '\n';
321   emitIndent();
322   os() << "    buf_ptrs," << '\n';
323   emitIndent();
324   os() << "    buf_ranks," << '\n';
325   emitIndent();
326   os() << "    buf_dims," << '\n';
327   emitIndent();
328   os() << "    buf_dtypes," << '\n';
329   emitIndent();
330   os() << "    " << v->args().size() << "," << '\n';
331   emitIndent();
332   os() << "    extra_args);" << '\n';
333 
334   indent_--;
335   emitIndent();
336   os() << "}" << '\n';
337 }
338 
visit(const LetPtr & v)339 void CppPrinter::visit(const LetPtr& v) {
340   if (v->var()->dtype().lanes() == 1) {
341     emitIndent();
342     os() << v->var()->dtype().ToCppString() << " " << *v->var() << " = "
343          << *v->value() << ";" << '\n';
344   } else {
345     vector_vars_[v->var()] = v->value();
346   }
347 }
348 
visit(const VarPtr & v)349 void CppPrinter::visit(const VarPtr& v) {
350   if (v->dtype().lanes() == 1) {
351     os() << name_manager()->get_unique_name(v);
352   } else {
353     os() << *vector_vars_.at(v);
354   }
355 }
356 
CppCodeGen(StmtPtr stmt,const std::vector<BufferArg> & buffer_args,at::Device device,const std::string & kernel_func_name)357 CppCodeGen::CppCodeGen(
358     StmtPtr stmt,
359     const std::vector<BufferArg>& buffer_args,
360     at::Device device,
361     const std::string& kernel_func_name)
362     : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
363   init();
364 }
365 
init()366 void CppCodeGen::init() {
367   printer_ = std::make_unique<CppPrinter>(&oss_);
368   var_name_rewriter_ = std::make_unique<CppVarNameRewriter>();
369 
370   apply_visitor(var_name_rewriter_.get());
371 
372   printer_->printPrologue();
373   os() << "void " << kernel_func_name() << "(";
374   const std::vector<BufferArg> buffer_args = this->buffer_args();
375   for (size_t i = 0; i < buffer_args.size(); i++) {
376     if (i > 0) {
377       os() << ", ";
378     }
379     const BufferArg& buffer_arg = buffer_args[i];
380     const VarPtr var = buffer_arg.var();
381     Dtype dtype = buffer_arg.dtype();
382     os() << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ") << *var;
383   }
384   os() << ")";
385   stmt()->accept(printer_.get());
386   os() << '\n';
387 }
388 
// Defined out of line in this file, where CppVarNameRewriter (held through
// var_name_rewriter_) is a complete type. NOTE(review): presumably the header
// only forward-declares it — confirm.
389 CppCodeGen::~CppCodeGen() = default;
390 
call(const std::vector<CallArg> & args)391 void CppCodeGen::call(const std::vector<CallArg>& args) {
392   // TODO: compile the generated C++ kernel into a library,
393   // and call the library here.
394   os() << "int main() {}" << '\n';
395 }
396 
call_raw(const std::vector<void * > & args)397 void CppCodeGen::call_raw(const std::vector<void*>& args) {
398   // TODO: compile the generated C++ kernel into a library,
399   // and call the library here.
400   os() << "int main() {}" << '\n';
401 }
402 
403 RegisterCodeGen<CppCodeGen> cpp_codegen_reg("cpp_codegen");
404 
405 } // namespace torch::jit::tensorexpr
406