xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/eval.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/eval.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/tensorexpr/external_functions_core.h>
5 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
6 
7 #include <c10/util/irange.h>
8 
9 #include <utility>
10 
11 namespace torch::jit::tensorexpr {
12 
// Namespace-scope RegisterCodeGen object pairing SimpleIREvaluator with the
// name "simple_ir_eval". Presumably the registration itself happens in
// RegisterCodeGen's constructor during static initialization — TODO confirm.
RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");
14 
intValue() const15 int64_t InterpValue::intValue() const {
16 #define TYPE_CASE(Type, Name)        \
17   if (dtype_ == k##Name) {           \
18     return int64_t{Name##values[0]}; \
19   }
20   AT_FORALL_INT_TYPES(TYPE_CASE);
21 #undef TYPE_CASE
22   throw unsupported_dtype();
23   return 0;
24 }
25 
// Remainder of lhs / rhs for integral types.
//
// Built-in integer `%` has undefined behaviour in two cases, and both are
// handled explicitly so the interpreter never executes UB:
//  - rhs == 0: throws std::runtime_error (the same exception type the bool
//    overload below uses);
//  - rhs == -1 on a signed type: returns 0, which is the exact mathematical
//    result. This avoids `min() % -1`, whose implied quotient overflows.
template <typename T>
inline std::enable_if_t<std::is_integral_v<T>, T> mod_value(T lhs, T rhs) {
  if (rhs == 0) {
    throw std::runtime_error("Attempted modulus by zero");
  }
  if constexpr (std::is_signed_v<T>) {
    if (rhs == static_cast<T>(-1)) {
      return 0;
    }
  }
  return lhs % rhs;
}

// Floating-point remainder, with the sign of lhs, as computed by std::fmod.
template <typename T>
inline std::enable_if_t<std::is_floating_point_v<T>, T> mod_value(
    T lhs,
    T rhs) {
  return std::fmod(lhs, rhs);
}

// Modulus is not defined for bool.
inline bool mod_value(bool /*lhs*/, bool /*rhs*/) {
  throw std::runtime_error("Attempted modulus of bool");
}
41 
42 template <typename T>
div_value(T lhs,T rhs)43 inline std::enable_if_t<std::is_integral_v<T>, T> div_value(T lhs, T rhs) {
44   TORCH_CHECK(rhs != 0, "Division by zero");
45   return lhs / rhs;
46 }
47 
48 template <typename T>
49 inline std::enable_if_t<std::is_floating_point_v<T>, T>
div_value(T lhs,T rhs)50     __ubsan_ignore_float_divide_by_zero__ div_value(T lhs, T rhs) {
51   return lhs / rhs;
52 }
53 
div_value(bool lhs,bool rhs)54 inline bool div_value(bool lhs, bool rhs) {
55   LOG(FATAL) << "Attempted division of bool";
56   return false;
57 }
58 
div_value(c10::Half lhs,c10::Half rhs)59 inline c10::Half div_value(c10::Half lhs, c10::Half rhs) {
60   return lhs / rhs;
61 }
62 
div_value(c10::BFloat16 lhs,c10::BFloat16 rhs)63 inline c10::BFloat16 div_value(c10::BFloat16 lhs, c10::BFloat16 rhs) {
64   return lhs / rhs;
65 }
66 
67 class SimpleIREvaluatorImpl : public IRVisitor {
68  public:
69   SimpleIREvaluatorImpl() = default;
70 
71   ~SimpleIREvaluatorImpl() override = default;
72 
bindBuf(const BufPtr & buf,void * ptr)73   void bindBuf(const BufPtr& buf, void* ptr) {
74     GRAPH_DEBUG("Binding ptr ", ptr, " with buf ", buf->name_hint());
75     buffer_mapping_[buf] = ptr;
76   }
bindVar(const VarPtr & var,const InterpValue & val)77   void bindVar(const VarPtr& var, const InterpValue& val) {
78     eval_context_[var] = val;
79     GRAPH_DEBUG(
80         "Binding value ", val.intValue(), " with var ", var->name_hint());
81   }
82 
evaluateExpr(const ExprPtr & e)83   InterpValue evaluateExpr(const ExprPtr& e) {
84     e->accept(this);
85     return value_;
86   }
87 
value() const88   InterpValue value() const {
89     return value_;
90   }
91 
clear()92   void clear() {
93     eval_context_.clear();
94     buffer_mapping_.clear();
95     internal_buffers_.clear();
96   }
97 
visit(const AddPtr & v)98   TORCH_API void visit(const AddPtr& v) override {
99     visit_binary_op(v);
100   }
visit(const SubPtr & v)101   TORCH_API void visit(const SubPtr& v) override {
102     visit_binary_op(v);
103   }
visit(const MulPtr & v)104   TORCH_API void visit(const MulPtr& v) override {
105     visit_binary_op(v);
106   }
visit(const DivPtr & v)107   TORCH_API void visit(const DivPtr& v) override {
108     visit_binary_op(v);
109   }
visit(const ModPtr & v)110   TORCH_API void visit(const ModPtr& v) override {
111     visit_binary_op(v);
112   }
visit(const MaxPtr & v)113   TORCH_API void visit(const MaxPtr& v) override {
114     visit_binary_op(v, v->propagate_nans());
115   }
visit(const MinPtr & v)116   TORCH_API void visit(const MinPtr& v) override {
117     visit_binary_op(v, v->propagate_nans());
118   }
119 
visit(const AndPtr & v)120   TORCH_API void visit(const AndPtr& v) override {
121     visit_binary_op(v);
122   }
visit(const OrPtr & v)123   TORCH_API void visit(const OrPtr& v) override {
124     visit_binary_op(v);
125   }
visit(const XorPtr & v)126   TORCH_API void visit(const XorPtr& v) override {
127     visit_binary_op(v);
128   }
visit(const LshiftPtr & v)129   TORCH_API void visit(const LshiftPtr& v) override {
130     visit_binary_op(v);
131   }
visit(const RshiftPtr & v)132   TORCH_API void visit(const RshiftPtr& v) override {
133     visit_binary_op(v);
134   }
135 
visit(const CompareSelectPtr & v)136   void visit(const CompareSelectPtr& v) override {
137     visit_compare_select_op(v, v->compare_select_op());
138   }
139 
140   template <typename T>
max_value(T a,T b)141   typename std::enable_if_t<std::is_floating_point_v<T>, T> max_value(
142       T a,
143       T b) {
144     return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? b : a));
145   }
146 
147   template <typename T>
max_value(T a,T b)148   typename std::enable_if_t<!std::is_floating_point_v<T>, T> max_value(
149       T a,
150       T b) {
151     return a < b ? b : a;
152   }
153 
154   template <typename T>
min_value(T a,T b)155   typename std::enable_if_t<std::is_floating_point_v<T>, T> min_value(
156       T a,
157       T b) {
158     return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? a : b));
159   }
160 
161   template <typename T>
min_value(T a,T b)162   typename std::enable_if_t<!std::is_floating_point_v<T>, T> min_value(
163       T a,
164       T b) {
165     return a < b ? a : b;
166   }
167 
168   template <typename T>
binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)169   InterpValue binary_op(
170       const InterpValue& lhs,
171       const InterpValue& rhs,
172       IRNodeType op_type) {
173     std::vector<T> lhs_v = lhs.as_vec<T>();
174     std::vector<T> rhs_v = rhs.as_vec<T>();
175     std::vector<T> result_v(lhs_v.size());
176     for (const auto i : c10::irange(lhs_v.size())) {
177       switch (op_type) {
178         case IRNodeType::kAdd:
179           result_v[i] = lhs_v[i] + rhs_v[i];
180           break;
181         case IRNodeType::kSub:
182           result_v[i] = lhs_v[i] - rhs_v[i];
183           break;
184         case IRNodeType::kMul:
185           result_v[i] = lhs_v[i] * rhs_v[i];
186           break;
187         case IRNodeType::kDiv:
188           result_v[i] = div_value(lhs_v[i], rhs_v[i]);
189           break;
190         case IRNodeType::kMod:
191           result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
192           break;
193         case IRNodeType::kMax:
194           result_v[i] = max_value(lhs_v[i], rhs_v[i]);
195           break;
196         case IRNodeType::kMin:
197           result_v[i] = min_value(lhs_v[i], rhs_v[i]);
198           break;
199         default:
200           // TODO: change to a proper error report
201           throw std::runtime_error("invalid operator type");
202       }
203     }
204     return InterpValue(result_v);
205   }
206 
207   template <typename T>
bitwise_binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)208   InterpValue bitwise_binary_op(
209       const InterpValue& lhs,
210       const InterpValue& rhs,
211       IRNodeType op_type) {
212     std::vector<T> lhs_v = lhs.as_vec<T>();
213     std::vector<T> rhs_v = rhs.as_vec<T>();
214     std::vector<T> result_v(lhs_v.size());
215     for (const auto i : c10::irange(lhs_v.size())) {
216       switch (op_type) {
217         case IRNodeType::kAnd:
218           result_v[i] = lhs_v[i] & rhs_v[i];
219           break;
220         case IRNodeType::kOr:
221           result_v[i] = lhs_v[i] | rhs_v[i];
222           break;
223         case IRNodeType::kXor:
224           result_v[i] = lhs_v[i] ^ rhs_v[i];
225           break;
226         default:
227           // TODO: change to a proper error report
228           throw std::runtime_error("invalid operator type");
229       }
230     }
231     return InterpValue(result_v);
232   }
233 
234   template <typename T>
shift_binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)235   InterpValue shift_binary_op(
236       const InterpValue& lhs,
237       const InterpValue& rhs,
238       IRNodeType op_type) {
239     std::vector<T> lhs_v = lhs.as_vec<T>();
240     std::vector<T> rhs_v = rhs.as_vec<T>();
241     std::vector<T> result_v(lhs_v.size());
242     for (const auto i : c10::irange(lhs_v.size())) {
243       switch (op_type) {
244         case IRNodeType::kLshift: {
245           auto a = static_cast<std::make_unsigned_t<T>>(lhs_v[i]);
246           result_v[i] = a << rhs_v[i];
247           break;
248         }
249         case IRNodeType::kRshift:
250           result_v[i] = lhs_v[i] >> rhs_v[i];
251           break;
252         default:
253           // TODO: change to a proper error report
254           throw std::runtime_error("invalid operator type");
255       }
256     }
257     return InterpValue(result_v);
258   }
259 
260   template <typename T, typename R>
compare_select_op(const InterpValue & lhs,const InterpValue & rhs,const InterpValue & retval1,const InterpValue & retval2,CompareSelectOperation cmp_op)261   InterpValue compare_select_op(
262       const InterpValue& lhs,
263       const InterpValue& rhs,
264       const InterpValue& retval1,
265       const InterpValue& retval2,
266       CompareSelectOperation cmp_op) {
267     std::vector<T> lhs_v = lhs.as_vec<T>();
268     std::vector<T> rhs_v = rhs.as_vec<T>();
269     std::vector<R> ret_val1_v = retval1.as_vec<R>();
270     std::vector<R> ret_val2_v = retval2.as_vec<R>();
271     std::vector<R> result_v(lhs_v.size());
272     for (const auto i : c10::irange(lhs_v.size())) {
273       switch (cmp_op) {
274         case CompareSelectOperation::kEQ:
275           result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
276           break;
277         case CompareSelectOperation::kNE:
278           result_v[i] = (lhs_v[i] != rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
279           break;
280         case CompareSelectOperation::kGT:
281           result_v[i] = (lhs_v[i] > rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
282           break;
283         case CompareSelectOperation::kGE:
284           result_v[i] = (lhs_v[i] >= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
285           break;
286         case CompareSelectOperation::kLT:
287           result_v[i] = (lhs_v[i] < rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
288           break;
289         case CompareSelectOperation::kLE:
290           result_v[i] = (lhs_v[i] <= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
291           break;
292         default:
293           // TODO: change to a proper error report
294           throw std::runtime_error("invalid operator type");
295       }
296     }
297     return InterpValue(result_v);
298   }
299 
300   template <
301       typename D,
302       std::enable_if_t<std::is_same_v<
303           decltype(detail::bin_op_deducer(std::declval<D>())),
304           void>>* = nullptr>
visit_binary_op(NodePtr<D> v,bool option=false)305   void visit_binary_op(NodePtr<D> v, bool option = false) {
306     v->lhs()->accept(this);
307     InterpValue lhs_v = value_;
308     v->rhs()->accept(this);
309     InterpValue rhs_v = value_;
310     if (lhs_v.dtype() != rhs_v.dtype()) {
311       throw malformed_input("bad dtype in binary op", v);
312     }
313 
314     IRNodeType expr_type = v->expr_type();
315     if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kOr ||
316         expr_type == IRNodeType::kXor) {
317       switch (lhs_v.dtype().scalar_type()) {
318 #define TYPE_CASE(Type, Name)                                  \
319   case ScalarType::Name:                                       \
320     value_ = bitwise_binary_op<Type>(lhs_v, rhs_v, expr_type); \
321     break;
322         AT_FORALL_INT_TYPES(TYPE_CASE);
323 #undef TYPE_CASE
324         case ScalarType::Bool:
325           value_ = bitwise_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
326           break;
327         default:
328           throw unsupported_dtype();
329       }
330       return;
331     }
332 
333     if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
334       switch (lhs_v.dtype().scalar_type()) {
335 #define TYPE_CASE(Type, Name)                                \
336   case ScalarType::Name:                                     \
337     value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
338     break;
339         AT_FORALL_INT_TYPES(TYPE_CASE);
340 #undef TYPE_CASE
341         case ScalarType::Bool:
342           value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
343           break;
344         default:
345           throw unsupported_dtype();
346       }
347       return;
348     }
349 
350     switch (lhs_v.dtype().scalar_type()) {
351 #define TYPE_CASE(Type, Name)                          \
352   case ScalarType::Name:                               \
353     value_ = binary_op<Type>(lhs_v, rhs_v, expr_type); \
354     break;
355       AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
356 #undef TYPE_CASE
357       case ScalarType::Bool:
358         value_ = binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
359         break;
360       default:
361         throw unsupported_dtype();
362     }
363   }
364 
365   template <typename T>
compare_select_op_helper(const InterpValue & lhs,const InterpValue & rhs,const InterpValue & retval1,const InterpValue & retval2,CompareSelectOperation cmp_op)366   InterpValue compare_select_op_helper(
367       const InterpValue& lhs,
368       const InterpValue& rhs,
369       const InterpValue& retval1,
370       const InterpValue& retval2,
371       CompareSelectOperation cmp_op) {
372     InterpValue value;
373     switch (retval1.dtype().scalar_type()) {
374 #define TYPE_CASE(Type, Name)                                               \
375   case ScalarType::Name:                                                    \
376     value = compare_select_op<T, Type>(lhs, rhs, retval1, retval2, cmp_op); \
377     break;
378       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
379 #undef TYPE_CASE
380       default:
381         throw unsupported_dtype();
382     }
383 
384     return value;
385   }
386 
visit_compare_select_op(const CompareSelectPtr & v,CompareSelectOperation cmp_op)387   void visit_compare_select_op(
388       const CompareSelectPtr& v,
389       CompareSelectOperation cmp_op) {
390     v->lhs()->accept(this);
391     InterpValue lhs_v = value_;
392     v->rhs()->accept(this);
393     InterpValue rhs_v = value_;
394     v->ret_val1()->accept(this);
395     InterpValue ret_val1_v = value_;
396     v->ret_val2()->accept(this);
397     InterpValue ret_val2_v = value_;
398 
399     if (lhs_v.dtype() != rhs_v.dtype() ||
400         ret_val1_v.dtype() != ret_val2_v.dtype()) {
401       throw malformed_input("bad dtype in CompareSelect", v);
402     }
403 
404     switch (lhs_v.dtype().scalar_type()) {
405 #define TYPE_CASE(Type, Name)                          \
406   case ScalarType::Name:                               \
407     value_ = compare_select_op_helper<Type>(           \
408         lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \
409     break;
410       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
411 #undef TYPE_CASE
412       default:
413         throw unsupported_dtype();
414     }
415   }
416 
417 #define IMM_VISIT(Type, Name)                            \
418   TORCH_API void visit(const Name##ImmPtr& v) override { \
419     value_ = InterpValue(v->value());                    \
420   }
421   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
422 #undef IMM_VISIT
423 
visit(const BlockPtr & v)424   TORCH_API void visit(const BlockPtr& v) override {
425     BlockPtr last = scope_;
426     scope_ = v;
427     for (const StmtPtr& s : v->stmts()) {
428       s->accept(this);
429     }
430 
431     auto it = var_by_scope_.find(v);
432     if (it != var_by_scope_.end()) {
433       for (const ExprPtr& v : it->second) {
434         eval_context_.erase(v);
435       }
436       var_by_scope_.erase(it);
437     }
438 
439     scope_ = last;
440   }
441 
visit(const VarPtr & v)442   TORCH_API void visit(const VarPtr& v) override {
443     auto iter = eval_context_.find(v);
444     if (iter == eval_context_.end()) {
445       throw malformed_input("could not find Var in context", v);
446     }
447 
448     value_ = iter->second;
449   }
450 
451   // disable ubsan because sometimes this performs out-of-bound casts
452   // e.g. it will cast negative floats to unsigned char
453   template <typename SrcType, typename DstType>
castValues(const Dtype & src_dtype,const InterpValue & v)454   std::vector<DstType> castValues(const Dtype& src_dtype, const InterpValue& v)
455       __ubsan_ignore_undefined__ {
456     const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
457     std::vector<DstType> dst_values(src_values.size());
458     for (int i = 0; i < src_dtype.lanes(); ++i) {
459       // NOLINTNEXTLINE(bugprone-signed-char-misuse)
460       dst_values[i] = static_cast<DstType>(underlyingValue(src_values[i]));
461     }
462     return dst_values;
463   }
464 
465   template <typename SrcType>
doCastFromSrc(const Dtype & src_dtype,const Dtype & dst_dtype,const InterpValue & v)466   void doCastFromSrc(
467       const Dtype& src_dtype,
468       const Dtype& dst_dtype,
469       const InterpValue& v) {
470     switch (dst_dtype.scalar_type()) {
471 #define DST_TYPE_CASE(Type, Name)                                        \
472   case ScalarType::Name:                                                 \
473     this->value_ = InterpValue(castValues<SrcType, Type>(src_dtype, v)); \
474     break;
475       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
476 #undef DST_TYPE_CASE
477 #define DST_TYPE_CASE_QUANT(Type, Name, CppType)                           \
478   case ScalarType::Name: {                                                 \
479     std::vector<CppType> vec = castValues<SrcType, CppType>(dst_dtype, v); \
480     std::vector<Type> qvec;                                                \
481     qvec.reserve(vec.size());                                              \
482     for (CppType u : vec) {                                                \
483       qvec.emplace_back(u);                                                \
484     }                                                                      \
485     this->value_ = InterpValue(qvec);                                      \
486   } break;
487       DST_TYPE_CASE_QUANT(c10::quint8, QUInt8, uint8_t)
488       DST_TYPE_CASE_QUANT(c10::qint8, QInt8, int8_t)
489 #undef DST_TYPE_CASE_QUANT
490       default:
491         throw unsupported_dtype();
492     }
493   }
494 
visit(const CastPtr & v)495   TORCH_API void visit(const CastPtr& v) override {
496     ExprPtr src_value = v->src_value();
497     src_value->accept(this);
498     Dtype dst_dtype = v->dtype();
499     Dtype src_dtype = src_value->dtype();
500     if (src_dtype.lanes() != dst_dtype.lanes()) {
501       throw malformed_input("lane mismatch in Cast", v);
502     }
503 
504     if (src_dtype != dst_dtype) {
505       switch (src_dtype.scalar_type()) {
506 #define SRC_TYPE_CASE(Type, Name)                      \
507   case ScalarType::Name:                               \
508     doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
509     break;
510         AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
511         SRC_TYPE_CASE(c10::quint8, QUInt8);
512         SRC_TYPE_CASE(c10::qint8, QInt8);
513 #undef SRC_TYPE_CASE
514         default:
515           throw unsupported_dtype();
516       }
517     }
518   }
519 
520   template <typename SrcType, typename DstType>
bitcastValues(const Dtype & src_dtype,const InterpValue & v)521   std::vector<DstType> bitcastValues(
522       const Dtype& src_dtype,
523       const InterpValue& v) {
524     const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
525     std::vector<DstType> dst_values(src_values.size());
526     for (int i = 0; i < src_dtype.lanes(); ++i) {
527       dst_values[i] = raw_bitcast<DstType>(src_values[i]);
528     }
529     return dst_values;
530   }
531 
532   template <typename SrcType>
doBitCastFromSrc(const Dtype & src_dtype,const Dtype & dst_dtype,const InterpValue & v)533   void doBitCastFromSrc(
534       const Dtype& src_dtype,
535       const Dtype& dst_dtype,
536       const InterpValue& v) {
537     switch (dst_dtype.scalar_type()) {
538 #define DST_TYPE_CASE(Type, Name)                                           \
539   case ScalarType::Name:                                                    \
540     this->value_ = InterpValue(bitcastValues<SrcType, Type>(src_dtype, v)); \
541     break;
542       // bool/half not supported
543       AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE);
544 #undef DST_TYPE_CASE
545       default:
546         throw unsupported_dtype();
547     }
548   }
549 
visit(const BitCastPtr & v)550   TORCH_API void visit(const BitCastPtr& v) override {
551     ExprPtr src_value = v->src_value();
552     src_value->accept(this);
553     Dtype dst_dtype = v->dtype();
554     Dtype src_dtype = src_value->dtype();
555     if (src_dtype.byte_size() != dst_dtype.byte_size()) {
556       throw malformed_input("lane mismatch in Cast", v);
557     }
558     if (src_dtype != dst_dtype) {
559       switch (src_dtype.scalar_type()) {
560 #define SRC_TYPE_CASE(Type, Name)                         \
561   case ScalarType::Name:                                  \
562     doBitCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
563     break;
564         // bool/half not supported
565         AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE);
566 #undef SRC_TYPE_CASE
567         default:
568           throw unsupported_dtype();
569       }
570     }
571   }
572 
visit(const ForPtr & v)573   TORCH_API void visit(const ForPtr& v) override {
574     ExprPtr var_node = v->var();
575     v->start()->accept(this);
576     auto dtype = value_.dtype();
577     auto start = value_.intValue();
578     v->stop()->accept(this);
579     auto stop = value_.intValue();
580     if (eval_context_.count(var_node)) {
581       throw malformed_input("could not find var_node in For context", v);
582     }
583 
584     for (auto i = start; i < stop; i++) {
585       eval_context_[var_node] = InterpValue(dtype, i);
586       if (v->body()) {
587         v->body()->accept(this);
588       }
589     }
590     eval_context_.erase(var_node);
591   }
592 
visit(const RampPtr & v)593   TORCH_API void visit(const RampPtr& v) override {
594     v->base()->accept(this);
595     auto base = value().intValue();
596     v->stride()->accept(this);
597     auto stride = value().intValue();
598     int lanes = v->lanes();
599 
600     std::vector<int64_t> values(lanes);
601     for (const auto i : c10::irange(lanes)) {
602       values[i] = base + i * stride;
603     }
604 
605     value_ = InterpValue(values);
606   }
607 
visit(const BroadcastPtr & v)608   TORCH_API void visit(const BroadcastPtr& v) override {
609     v->value()->accept(this);
610     InterpValue value = this->value();
611     int lanes = v->lanes();
612     switch (value.dtype().scalar_type()) {
613 #define TYPE_CASE(Type, Name)                     \
614   case ScalarType::Name: {                        \
615     std::vector<Type> v(lanes, value.as<Type>()); \
616     value_ = InterpValue(v);                      \
617   } break;
618       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
619 #undef TYPE_CASE
620       default:
621         throw unsupported_dtype();
622     }
623   }
624 
visit(const IfThenElsePtr & v)625   TORCH_API void visit(const IfThenElsePtr& v) override {
626     v->condition()->accept(this);
627     bool cond_v = false;
628     switch (value_.dtype().scalar_type()) {
629 #define TYPE_CASE(Type, Name)   \
630   case ScalarType::Name: {      \
631     cond_v = value_.as<Type>(); \
632   } break;
633       AT_FORALL_SCALAR_TYPES_AND(Bool, TYPE_CASE);
634 #undef TYPE_CASE
635       case ScalarType::Half:
636         throw unsupported_dtype("IfThenElse condition can't have Half dtype");
637       case ScalarType::BFloat16:
638         throw unsupported_dtype(
639             "IfThenElse condition can't have BFloat16 dtype");
640       default:
641         throw unsupported_dtype();
642     }
643 
644     if (cond_v) {
645       v->true_value()->accept(this);
646     } else {
647       v->false_value()->accept(this);
648     }
649   }
650 
651   template <typename T>
toLongVec(T && t)652   std::vector<int64_t> toLongVec(T&& t) {
653     return std::vector<int64_t>{std::begin(t), std::end(t)};
654   }
655 
indexVec(const InterpValue & v)656   std::vector<int64_t> indexVec(const InterpValue& v) {
657     switch (v.dtype().scalar_type()) {
658 #define TYPE_CASE(Type, Name) \
659   case ScalarType::Name:      \
660     return toLongVec(v.as_vec<Type>());
661       AT_FORALL_INT_TYPES(TYPE_CASE);
662 #undef TYPE_CASE
663       default:
664         throw unsupported_dtype();
665     }
666     return {};
667   }
668 
check_bounds_throw(int64_t idx,int64_t bound,const BufPtr & buf)669   void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) {
670     std::stringstream ss;
671     ss << "Index out of bounds in check_bounds. Index: " << idx
672        << "; bounds: [0, " << bound << ").";
673     throw malformed_input(ss.str(), buf);
674   }
675 
check_bounds(const BufPtr & buf,const std::vector<ExprPtr> & indices)676   void check_bounds(const BufPtr& buf, const std::vector<ExprPtr>& indices) {
677     const std::vector<ExprPtr>& dims = buf->dims();
678     if (dims.size() != indices.size()) {
679       // indices are flattened, but not buffer
680       if (indices.size() == 1) {
681         if (dims.size() != buf->strides().size()) {
682           throw malformed_input(
683               "Number of dimensions did not match number of strides", buf);
684         }
685         int64_t buf_size = 1;
686         if (!dims.empty()) {
687           ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
688           ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
689           for (const auto& i : c10::irange(dims.size())) {
690             buf_size_expr = buf_size_expr +
691                 ((negative_one + ExprHandle(dims[i])) *
692                  ExprHandle(buf->strides()[i]));
693           }
694           buf_size_expr.node()->accept(this);
695           buf_size = value().intValue();
696         }
697         indices[0]->accept(this);
698         const auto& index_values = indexVec(value());
699         for (auto& j : index_values) {
700           if (j < 0 || j >= buf_size) {
701             check_bounds_throw(j, buf_size, buf);
702           }
703         }
704         return;
705       }
706       throw malformed_input(
707           "dimensions and indices mismatch in check_bounds. Buf has " +
708               std::to_string(dims.size()) + " dimensions and indices has " +
709               std::to_string(indices.size()) + " dimensions.",
710           buf);
711     }
712     for (const auto& i : c10::irange(dims.size())) {
713       auto opt_dim = intValue(dims[i]);
714       if (!opt_dim) {
715         continue;
716       }
717       auto dim_bound = *opt_dim;
718       indices[i]->accept(this);
719       const auto& ithDimIndices = indexVec(value());
720       for (auto& j : ithDimIndices) {
721         if (j < 0 || j >= dim_bound) {
722           check_bounds_throw(j, dim_bound, buf);
723         }
724       }
725     }
726   }
727 
visit(const LoadPtr & v)728   TORCH_API void visit(const LoadPtr& v) override {
729     auto iter = buffer_mapping_.find(v->buf());
730     if (iter == buffer_mapping_.end()) {
731       throw malformed_input("could not find base node in Load", v);
732     }
733     void* ptr = iter->second;
734 
735     check_bounds(v->buf(), v->indices());
736 
737     ExprPtr flat_idx =
738         flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
739     flat_idx->accept(this);
740     auto index = indexVec(value());
741     ScalarType v_sdtype = v->dtype().scalar_type();
742     switch (v_sdtype) {
743 #define TYPE_CASE(Type, Name)                        \
744   case ScalarType::Name: {                           \
745     Type* ptr##Name = static_cast<Type*>(ptr);       \
746     std::vector<Type> val(index.size());             \
747     for (const auto i : c10::irange(index.size())) { \
748       val[i] = ptr##Name[index[i]];                  \
749       GRAPH_DEBUG(                                   \
750           "LOAD: ptr=",                              \
751           ptr##Name,                                 \
752           ", buf=",                                  \
753           v->buf()->name_hint(),                     \
754           ", idx=",                                  \
755           index[i],                                  \
756           ", val=",                                  \
757           (int)underlyingValue(val[i]));             \
758     }                                                \
759     value_ = InterpValue(val);                       \
760   } break;
761       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
762       TYPE_CASE(c10::quint8, QUInt8);
763       TYPE_CASE(c10::qint8, QInt8);
764 #undef TYPE_CASE
765       default:
766         throw unsupported_dtype("scalar type:" + std::to_string(v_sdtype));
767     }
768   }
769 
visit(const StorePtr & v)770   TORCH_API void visit(const StorePtr& v) override {
771     auto iter = buffer_mapping_.find(v->buf());
772     if (iter == buffer_mapping_.end()) {
773       throw malformed_input("could not find base node in Store", v);
774     }
775 
776     void* ptr = iter->second;
777 
778     check_bounds(v->buf(), v->indices());
779 
780     ExprPtr flat_idx =
781         flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
782     flat_idx->accept(this);
783     auto index = indexVec(value());
784     ScalarType v_sdtype = v->value()->dtype().scalar_type();
785 
786     switch (v_sdtype) {
787 #define TYPE_CASE(Type, Name)                                   \
788   case ScalarType::Name: {                                      \
789     v->value()->accept(this);                                   \
790     std::vector<Type> value = this->value().as_vec<Type>();     \
791     if (index.size() != value.size()) {                         \
792       throw malformed_input("value size mismatch in Store", v); \
793     }                                                           \
794     Type* ptr##Name = static_cast<Type*>(ptr);                  \
795     for (const auto i : c10::irange(index.size())) {            \
796       GRAPH_DEBUG(                                              \
797           "STORE: ptr=",                                        \
798           ptr##Name,                                            \
799           ", buf=",                                             \
800           v->buf()->name_hint(),                                \
801           ", idx=",                                             \
802           index[i],                                             \
803           ", val=",                                             \
804           (int)underlyingValue(value[i]));                      \
805       ptr##Name[index[i]] = value[i];                           \
806     }                                                           \
807   } break;
808       AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
809       TYPE_CASE(c10::quint8, QUInt8);
810       TYPE_CASE(c10::qint8, QInt8);
811 #undef TYPE_CASE
812       default:
813         throw unsupported_dtype();
814     }
815   }
816 
visit(const ExternalCallPtr & v)817   void visit(const ExternalCallPtr& v) override {
818     auto& func_registry = getNNCFunctionRegistry();
819     if (!func_registry.count(v->func_name())) {
820       throw unimplemented_lowering(v);
821     }
822     GRAPH_DEBUG(
823         "EXTERNAL CALL: func=",
824         v->func_name(),
825         ", buf=",
826         v->buf()->name_hint());
827 
828     std::vector<BufPtr> bufs(v->buf_args());
829     bufs.insert(bufs.begin(), v->buf());
830 
831     std::vector<void*> buf_ptrs;
832     std::vector<int64_t> buf_ranks;
833     std::vector<int64_t> buf_dims;
834     std::vector<int64_t> buf_strides;
835     std::vector<int8_t> buf_dtypes;
836     std::vector<int64_t> extra_args;
837 
838     for (const BufPtr& b : bufs) {
839       auto iter = buffer_mapping_.find(b);
840       if (iter == buffer_mapping_.end()) {
841         throw malformed_input("could not find buf", v);
842       }
843 
844       buf_ptrs.push_back(iter->second);
845       buf_ranks.push_back(b->dims().size());
846       buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
847       for (const ExprPtr& dim_expr : b->dims()) {
848         dim_expr->accept(this);
849         buf_dims.push_back(value().intValue());
850       }
851       for (const ExprPtr& stride_expr : b->strides()) {
852         stride_expr->accept(this);
853         buf_strides.push_back(value().intValue());
854       }
855     }
856     for (const ExprPtr& a : v->args()) {
857       a->accept(this);
858       int64_t val = 0;
859       if (value().dtype() == kLong) {
860         val = value().as<int64_t>();
861       } else if (value().dtype() == kInt) {
862         val = value().intValue();
863       } else if (value().dtype() == kDouble) {
864         auto x = value().as<double>();
865         val = reinterpret_cast<int64_t*>(&x)[0];
866       } else if (value().dtype() == kFloat) {
867         auto x = value().as<float>();
868         val = reinterpret_cast<int64_t*>(&x)[0];
869       } else {
870         throw malformed_input(
871             "extra_args in ExternalCalls must have int64 dtype", v);
872       }
873       extra_args.push_back(val);
874     }
875 
876     auto fn_ptr = func_registry.at(v->func_name());
877     (*fn_ptr)(
878         bufs.size(),
879         buf_ptrs.data(),
880         buf_ranks.data(),
881         buf_dims.data(),
882         buf_strides.data(),
883         buf_dtypes.data(),
884         extra_args.size(),
885         extra_args.data());
886   }
887 
visit(const ExternalCallWithAllocPtr & v)888   void visit(const ExternalCallWithAllocPtr& v) override {
889     auto& func_registry = getNNCFunctionRegistry();
890     if (!func_registry.count(v->func_name())) {
891       throw unimplemented_lowering(v);
892     }
893     GRAPH_DEBUG("EXTERNAL CALL: func=", v->func_name());
894 
895     const auto& bufs_out = v->buf_out_args();
896     const auto& bufs_in = v->buf_args();
897     const auto bufs_in_size = bufs_in.size();
898     const auto bufs_out_size = bufs_out.size();
899 
900     std::vector<void*> buf_ptrs(bufs_in_size + 2 * bufs_out_size);
901     std::vector<int64_t> buf_ranks;
902     std::vector<int64_t> buf_dims;
903     std::vector<int64_t> buf_strides;
904     std::vector<int8_t> buf_dtypes;
905     std::vector<int64_t> extra_args;
906 
907     size_t i = 0;
908     for (const auto& b : bufs_in) {
909       auto iter = buffer_mapping_.find(b);
910       if (iter == buffer_mapping_.end()) {
911         throw malformed_input("could not find buf", v);
912       }
913       buf_ptrs[bufs_out_size + i] = iter->second;
914       // @lint-ignore CLANGTIDY
915       buf_ranks.push_back(b->dims().size());
916       buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
917       for (const auto& dim_expr : b->dims()) {
918         dim_expr->accept(this);
919         buf_dims.push_back(value().intValue());
920       }
921       for (const ExprPtr& stride_expr : b->strides()) {
922         stride_expr->accept(this);
923         buf_strides.push_back(value().intValue());
924       }
925       i++;
926     }
927     for (const auto& a : v->args()) {
928       a->accept(this);
929       int64_t val = 0;
930       if (value().dtype() == kLong) {
931         val = value().as<int64_t>();
932       } else if (value().dtype() == kInt) {
933         val = value().intValue();
934       } else if (value().dtype() == kDouble) {
935         auto x = value().as<double>();
936         val = reinterpret_cast<int64_t*>(&x)[0];
937       } else if (value().dtype() == kFloat) {
938         auto x = value().as<float>();
939         val = reinterpret_cast<int64_t*>(&x)[0];
940       } else {
941         throw malformed_input(
942             "extra_args in ExternalCalls must have int64 dtype", v);
943       }
944       extra_args.push_back(val);
945     }
946 
947     auto fn_ptr = func_registry.at(v->func_name());
948     (*fn_ptr)(
949         bufs_in_size,
950         buf_ptrs.data(),
951         buf_ranks.data(),
952         buf_dims.data(),
953         buf_strides.data(),
954         buf_dtypes.data(),
955         extra_args.size(),
956         extra_args.data());
957 
958     for (i = 0; i < bufs_out_size; ++i) {
959       const auto& buf_out = bufs_out[i];
960       buffer_mapping_[buf_out] = buf_ptrs[i];
961       ext_bufs_free_ptr_[buf_out] = buf_ptrs[bufs_in_size + bufs_out_size + i];
962     }
963   }
964 
965   template <typename TReturn, typename TInput>
visit_intrinsics_helper(const IntrinsicsPtr & v)966   void visit_intrinsics_helper(const IntrinsicsPtr& v) {
967     std::vector<InterpValue> values(v->nparams());
968     for (const auto i : c10::irange(v->nparams())) {
969       v->param(i)->accept(this);
970       values[i] = this->value();
971     }
972     std::vector<TInput> v1;
973     if (!values.empty()) {
974       v1 = values[0].as_vec<TInput>();
975     }
976     std::vector<TInput> v2;
977     if (values.size() >= 2ULL) {
978       v2 = values[1].as_vec<TInput>();
979       if (v1.size() != v2.size()) {
980         throw malformed_input("value size mismatch in Intrinsics", v);
981       }
982     }
983 
984     if (values.size() > 2) {
985       throw unimplemented_lowering(v);
986     }
987 
988     std::vector<TReturn> result(v1.size(), -1);
989     if (values.size() == 1ULL) {
990       for (const auto i : c10::irange(v1.size())) {
991         result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i]);
992       }
993     } else {
994       for (const auto i : c10::irange(v1.size())) {
995         result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i], v2[i]);
996       }
997     }
998     value_ = InterpValue(result);
999   }
1000 
visit(const IntrinsicsPtr & v)1001   TORCH_API void visit(const IntrinsicsPtr& v) override {
1002     auto ty = v->dtype().scalar_type();
1003     if (v->op_type() == kIsNan) {
1004       auto inp_dtype = v->params().at(0)->dtype().scalar_type();
1005       if (inp_dtype == ScalarType::Float) {
1006         visit_intrinsics_helper<int, float>(v);
1007       } else if (inp_dtype == ScalarType::Double) {
1008         visit_intrinsics_helper<int, double>(v);
1009       } else if (inp_dtype == ScalarType::Half) {
1010         throw unsupported_dtype(); // TODO
1011       } else if (inp_dtype == ScalarType::BFloat16) {
1012         throw unsupported_dtype(); // TODO
1013       }
1014     } else {
1015       switch (ty) {
1016 #define TYPE_CASE(Type, Name)               \
1017   case ScalarType::Name:                    \
1018     visit_intrinsics_helper<Type, Type>(v); \
1019     break;
1020         AT_FORALL_SCALAR_TYPES(TYPE_CASE);
1021 #undef TYPE_CASE
1022         default:
1023           throw unsupported_dtype();
1024       }
1025     }
1026   }
1027 
visit(const AllocatePtr & v)1028   void visit(const AllocatePtr& v) override {
1029     BufPtr b = v->buf();
1030     std::vector<ExprPtr> dims = b->dims();
1031     int64_t total_byte_size = b->dtype().byte_size();
1032     for (auto& dim : dims) {
1033       dim->accept(this);
1034       total_byte_size *= value_.intValue();
1035     }
1036     auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
1037     GRAPH_DEBUG(
1038         "ALLOCATE: buf=", v->buf()->name_hint(), ", size=", total_byte_size);
1039     auto buffer = std::make_unique<std::vector<int>>(int_count);
1040     auto iter = buffer_mapping_.find(b);
1041     if (iter != buffer_mapping_.end() && iter->second != nullptr) {
1042       throw std::runtime_error(
1043           "Allocate a buffer that has already been allocated: " +
1044           v->buffer_var()->name_hint());
1045     }
1046     buffer_mapping_[b] = buffer->data();
1047     internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
1048   }
1049 
visit(const PlacementAllocatePtr & v)1050   void visit(const PlacementAllocatePtr& v) override {
1051     buffer_mapping_[v->buf()] = buffer_mapping_.at(v->buf_to_reuse());
1052   }
1053 
visit(const FreePtr & v)1054   void visit(const FreePtr& v) override {
1055     BufPtr b = v->buf();
1056     GRAPH_DEBUG("FREE: buf=", v->buf()->name_hint());
1057     auto count = internal_buffers_.erase(b);
1058     if (count == 0) {
1059       throw std::runtime_error(
1060           "Free a buffer that is not currently bound: " +
1061           v->buffer_var()->name_hint());
1062     }
1063     buffer_mapping_.erase(b);
1064   }
1065 
visit(const FreeExtPtr & v)1066   void visit(const FreeExtPtr& v) override {
1067     const auto& bufs = v->bufs();
1068     const auto bufs_num = bufs.size();
1069     std::vector<void*> buf_ptrs;
1070     for (const auto& buf : bufs) {
1071       if (!ext_bufs_free_ptr_.count(buf)) {
1072         throw std::runtime_error(
1073             "Free an external allocated buffer that does not have corresponding pointer for freeing: " +
1074             buf->base_handle()->name_hint());
1075       }
1076       buf_ptrs.push_back(ext_bufs_free_ptr_[buf]);
1077     }
1078     nnc_aten_free(bufs_num, buf_ptrs.data());
1079   }
1080 
visit(const LetPtr & v)1081   void visit(const LetPtr& v) override {
1082     var_by_scope_[scope_].push_back(v->var());
1083     bindVar(v->var(), evaluateExpr(v->value()));
1084   }
1085 
visit(const CondPtr & v)1086   void visit(const CondPtr& v) override {
1087     v->condition()->accept(this);
1088     if (value().intValue()) {
1089       if (v->true_stmt()) {
1090         v->true_stmt()->accept(this);
1091       }
1092     } else {
1093       if (v->false_stmt()) {
1094         v->false_stmt()->accept(this);
1095       }
1096     }
1097   }
1098 
1099  private:
1100   template <
1101       typename TReturn,
1102       typename TInput,
1103       std::enable_if_t<std::is_floating_point_v<TInput>, int> = 0>
compute_intrinsics(IntrinsicsOp op_type,TInput v)1104   static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) {
1105     switch (op_type) {
1106       case kSin:
1107         return std::sin(v);
1108       case kCos:
1109         return std::cos(v);
1110       case kTan:
1111         return std::tan(v);
1112       case kAsin:
1113         return std::asin(v);
1114       case kAcos:
1115         return std::acos(v);
1116       case kAtan:
1117         return std::atan(v);
1118       case kSinh:
1119         return std::sinh(v);
1120       case kCosh:
1121         return std::cosh(v);
1122       case kTanh:
1123         return std::tanh(v);
1124       case kExp:
1125         return std::exp(v);
1126       case kAbs:
1127         return std::abs(v);
1128       case kExpm1:
1129         return std::expm1(v);
1130       case kLog:
1131         return std::log(v);
1132       case kLog2:
1133         return std::log2(v);
1134       case kLog10:
1135         return std::log10(v);
1136       case kLog1p:
1137         return std::log1p(v);
1138       case kErf:
1139         return std::erf(v);
1140       case kErfc:
1141         return std::erfc(v);
1142       case kSqrt:
1143         return std::sqrt(v);
1144       case kRsqrt: {
1145         auto rsqrt = [](TInput v) __ubsan_ignore_float_divide_by_zero__ {
1146           return 1.0f / std::sqrt(v);
1147         };
1148         return rsqrt(v);
1149       }
1150       case kCeil:
1151         return std::ceil(v);
1152       case kFloor:
1153         return std::floor(v);
1154       case kRound:
1155         return std::round(v);
1156       case kTrunc:
1157         return std::trunc(v);
1158       case kLgamma:
1159         return std::lgamma(v);
1160       case kFrac:
1161         TInput intpart;
1162         return std::modf(v, &intpart);
1163       case kIsNan:
1164         return std::isnan(v);
1165       default:
1166         throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1167     }
1168   }
1169 
1170   template <
1171       typename TReturn,
1172       typename TInput,
1173       std::enable_if_t<std::is_integral_v<TInput>, int> = 0>
compute_intrinsics(IntrinsicsOp op_type,TInput v)1174   static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) {
1175     switch (op_type) {
1176       case kAbs: {
1177         // internal tool complains about calling `abs` on unsigned, the
1178         // following makes the tool happy
1179         using X = std::conditional_t<std::is_unsigned_v<TInput>, int, TInput>;
1180         return std::is_unsigned_v<TInput> ? v : std::abs(static_cast<X>(v));
1181       }
1182       default:
1183         throw std::runtime_error(
1184             "Invalid integral op_type: " + std::to_string(op_type));
1185     }
1186   }
1187 
1188   // specialization for float -> int ops (just kIsNan currently)
compute_intrinsics(IntrinsicsOp op_type,float v)1189   int compute_intrinsics(IntrinsicsOp op_type, float v) {
1190     switch (op_type) {
1191       case kIsNan:
1192         return std::isnan(v);
1193       default:
1194         throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1195     }
1196   }
1197 
1198   template <typename TReturn, typename TInput>
compute_intrinsics(IntrinsicsOp op_type,TInput v1,TInput v2)1199   TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v1, TInput v2) {
1200     switch (op_type) {
1201       case kPow:
1202         return std::pow(v1, v2);
1203       case kFmod:
1204         return std::fmod(v1, v2);
1205       case kRemainder:
1206         return std::remainder(v1, v2);
1207       case kAtan2:
1208         return std::atan2(v1, v2);
1209       default:
1210         throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1211     }
1212   }
1213 
  // Result of the most recently evaluated expression; visitors read it (via
  // value() or directly) right after calling accept().
  InterpValue value_;
  // NOTE(review): presumably the block currently being executed; Let records
  // its var under var_by_scope_[scope_].
  BlockPtr scope_;
  // NOTE(review): presumably the current value bound to each Var; not
  // touched in this chunk - confirm against bindVar.
  std::unordered_map<ExprPtr, InterpValue> eval_context_;
  // Vars introduced by Let statements, grouped by the scope they appeared in.
  std::unordered_map<BlockPtr, std::vector<ExprPtr>> var_by_scope_;
  // Data pointer for each bound buffer: set by Allocate, PlacementAllocate,
  // ExternalCallWithAlloc (and presumably bindBuf for kernel arguments).
  std::unordered_map<BufPtr, void*> buffer_mapping_;
  // Owning storage for buffers created by Allocate; entries are erased by
  // Free.
  std::unordered_map<BufPtr, std::unique_ptr<std::vector<int>>>
      internal_buffers_;
  // Pointers returned by ExternalCallWithAlloc callees; FreeExt passes them
  // to nnc_aten_free.
  std::unordered_map<BufPtr, void*> ext_bufs_free_ptr_;
1222 };
1223 
SimpleIREvaluator(StmtPtr stmt,const std::vector<BufferArg> & buffer_args,at::Device device,const std::string & kernel_func_name)1224 SimpleIREvaluator::SimpleIREvaluator(
1225     StmtPtr stmt,
1226     const std::vector<BufferArg>& buffer_args,
1227     at::Device device,
1228     const std::string& kernel_func_name)
1229     : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
1230   impl_ = std::make_unique<SimpleIREvaluatorImpl>();
1231   expand_intrinsics();
1232 }
1233 
// Defaulted out of line, in the file that defines SimpleIREvaluatorImpl.
// Presumably the header only forward-declares the impl type, so destroying
// impl_ has to happen where that type is complete.
SimpleIREvaluator::~SimpleIREvaluator() = default;
1235 
call(const std::vector<CallArg> & args)1236 void SimpleIREvaluator::call(const std::vector<CallArg>& args) {
1237   std::vector<void*> raw_args(args.size());
1238   for (size_t i = 0; i < args.size(); i++) {
1239     auto const& bufferArg = buffer_args()[i];
1240     auto const& callArg = args[i];
1241     raw_args[i] = argToPtr(bufferArg, callArg);
1242   }
1243   call_raw(raw_args);
1244 }
1245 
call_raw(const std::vector<void * > & args)1246 void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
1247   if (args.size() != buffer_args().size()) {
1248     throw malformed_input("bad args in IREvaluator call");
1249   }
1250   for (const auto i : c10::irange(args.size())) {
1251     bindArg(buffer_args()[i], args[i]);
1252   }
1253   stmt()->accept(&*impl_);
1254   impl_->clear();
1255 }
1256 
bindArg(const BufferArg & bufArg,void * data)1257 void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
1258   if (!bufArg.isVar()) {
1259     impl_->bindBuf(bufArg.buf(), data);
1260     return;
1261   }
1262 
1263   switch (bufArg.dtype().scalar_type()) {
1264 #define TYPE_CASE(Type, Name)                 \
1265   case ScalarType::Name: {                    \
1266     Type typed_data;                          \
1267     memcpy(&typed_data, data, sizeof(Type));  \
1268     impl_->bindVar(bufArg.var(), typed_data); \
1269     break;                                    \
1270   }
1271     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
1272 #undef TYPE_CASE
1273     default:
1274       throw unsupported_dtype();
1275   }
1276 }
1277 
bindVar(const VarPtr & v,const ExprPtr & e)1278 void SimpleIREvaluator::bindVar(const VarPtr& v, const ExprPtr& e) {
1279   impl_->bindVar(v, impl_->evaluateExpr(e));
1280 }
1281 
// Returns the interpreter's current result value, forwarded from
// impl_->value(); within the interpreter this is what visitors read right
// after evaluating an expression.
InterpValue SimpleIREvaluator::value() const {
  return impl_->value();
}
1285 
evalInt(ExprPtr e)1286 std::optional<int64_t> evalInt(ExprPtr e) {
1287   try {
1288     return ExprEval<SimpleIREvaluator>(cast<int64_t>(ExprHandle(std::move(e))))
1289         .value<int64_t>();
1290   } catch (std::runtime_error& err) {
1291     return std::nullopt;
1292   }
1293 }
1294 
1295 } // namespace torch::jit::tensorexpr
1296