#include <torch/csrc/jit/tensorexpr/eval.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/external_functions_core.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>

#include <c10/util/irange.h>

#include <cstring>
#include <stdexcept>
#include <utility>
10
11 namespace torch::jit::tensorexpr {
12
// Static registration object: constructs a RegisterCodeGen<SimpleIREvaluator>
// keyed by the string "simple_ir_eval". Presumably this makes the evaluator
// selectable by that name -- confirm against RegisterCodeGen.
RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");
14
intValue() const15 int64_t InterpValue::intValue() const {
16 #define TYPE_CASE(Type, Name) \
17 if (dtype_ == k##Name) { \
18 return int64_t{Name##values[0]}; \
19 }
20 AT_FORALL_INT_TYPES(TYPE_CASE);
21 #undef TYPE_CASE
22 throw unsupported_dtype();
23 return 0;
24 }
25
/// Integer remainder `lhs % rhs` with the undefined cases made explicit.
///
/// @param lhs dividend
/// @param rhs divisor
/// @return    `lhs % rhs`
/// @throws std::runtime_error if `rhs` is zero, where the built-in `%` has
///         undefined behaviour.
template <typename T>
inline std::enable_if_t<std::is_integral_v<T>, T> mod_value(T lhs, T rhs) {
  if (rhs == 0) {
    throw std::runtime_error("Modulus by zero");
  }
  if constexpr (std::is_signed_v<T>) {
    // x % -1 is always 0, but evaluating min() % -1 with the built-in
    // operator overflows; answer directly instead.
    if (rhs == -1) {
      return 0;
    }
  }
  return lhs % rhs;
}
30
// Floating-point remainder: delegates to std::fmod.
template <typename T>
inline std::enable_if_t<std::is_floating_point_v<T>, T> mod_value(
    T lhs,
    T rhs) {
  const T remainder = std::fmod(lhs, rhs);
  return remainder;
}
37
// Modulus of booleans is rejected: this overload always throws
// std::runtime_error and never returns.
inline bool mod_value(bool lhs, bool rhs) {
  const char* const message = "Attempted modulus of bool";
  throw std::runtime_error(message);
}
41
// Integer division. A zero divisor is rejected through TORCH_CHECK with the
// message "Division by zero" before the division is performed.
template <typename T>
inline std::enable_if_t<std::is_integral_v<T>, T> div_value(T lhs, T rhs) {
  TORCH_CHECK(rhs != 0, "Division by zero");
  const T quotient = lhs / rhs;
  return quotient;
}
47
48 template <typename T>
49 inline std::enable_if_t<std::is_floating_point_v<T>, T>
div_value(T lhs,T rhs)50 __ubsan_ignore_float_divide_by_zero__ div_value(T lhs, T rhs) {
51 return lhs / rhs;
52 }
53
/// Division of booleans is rejected.
///
/// Reports the error the same way as mod_value(bool, bool): by throwing, so
/// the caller can handle it, rather than terminating through a fatal log.
///
/// @throws std::runtime_error always.
inline bool div_value(bool lhs, bool rhs) {
  throw std::runtime_error("Attempted division of bool");
}
58
div_value(c10::Half lhs,c10::Half rhs)59 inline c10::Half div_value(c10::Half lhs, c10::Half rhs) {
60 return lhs / rhs;
61 }
62
div_value(c10::BFloat16 lhs,c10::BFloat16 rhs)63 inline c10::BFloat16 div_value(c10::BFloat16 lhs, c10::BFloat16 rhs) {
64 return lhs / rhs;
65 }
66
67 class SimpleIREvaluatorImpl : public IRVisitor {
68 public:
// Compiler-generated default constructor.
SimpleIREvaluatorImpl() = default;
70
// Compiler-generated destructor, overriding the base class destructor.
~SimpleIREvaluatorImpl() override = default;
72
bindBuf(const BufPtr & buf,void * ptr)73 void bindBuf(const BufPtr& buf, void* ptr) {
74 GRAPH_DEBUG("Binding ptr ", ptr, " with buf ", buf->name_hint());
75 buffer_mapping_[buf] = ptr;
76 }
bindVar(const VarPtr & var,const InterpValue & val)77 void bindVar(const VarPtr& var, const InterpValue& val) {
78 eval_context_[var] = val;
79 GRAPH_DEBUG(
80 "Binding value ", val.intValue(), " with var ", var->name_hint());
81 }
82
evaluateExpr(const ExprPtr & e)83 InterpValue evaluateExpr(const ExprPtr& e) {
84 e->accept(this);
85 return value_;
86 }
87
value() const88 InterpValue value() const {
89 return value_;
90 }
91
clear()92 void clear() {
93 eval_context_.clear();
94 buffer_mapping_.clear();
95 internal_buffers_.clear();
96 }
97
visit(const AddPtr & v)98 TORCH_API void visit(const AddPtr& v) override {
99 visit_binary_op(v);
100 }
visit(const SubPtr & v)101 TORCH_API void visit(const SubPtr& v) override {
102 visit_binary_op(v);
103 }
visit(const MulPtr & v)104 TORCH_API void visit(const MulPtr& v) override {
105 visit_binary_op(v);
106 }
visit(const DivPtr & v)107 TORCH_API void visit(const DivPtr& v) override {
108 visit_binary_op(v);
109 }
visit(const ModPtr & v)110 TORCH_API void visit(const ModPtr& v) override {
111 visit_binary_op(v);
112 }
visit(const MaxPtr & v)113 TORCH_API void visit(const MaxPtr& v) override {
114 visit_binary_op(v, v->propagate_nans());
115 }
visit(const MinPtr & v)116 TORCH_API void visit(const MinPtr& v) override {
117 visit_binary_op(v, v->propagate_nans());
118 }
119
visit(const AndPtr & v)120 TORCH_API void visit(const AndPtr& v) override {
121 visit_binary_op(v);
122 }
visit(const OrPtr & v)123 TORCH_API void visit(const OrPtr& v) override {
124 visit_binary_op(v);
125 }
visit(const XorPtr & v)126 TORCH_API void visit(const XorPtr& v) override {
127 visit_binary_op(v);
128 }
visit(const LshiftPtr & v)129 TORCH_API void visit(const LshiftPtr& v) override {
130 visit_binary_op(v);
131 }
visit(const RshiftPtr & v)132 TORCH_API void visit(const RshiftPtr& v) override {
133 visit_binary_op(v);
134 }
135
visit(const CompareSelectPtr & v)136 void visit(const CompareSelectPtr& v) override {
137 visit_compare_select_op(v, v->compare_select_op());
138 }
139
// Floating-point maximum that propagates NaN: if either operand is NaN that
// operand is returned (`a` takes precedence); otherwise the larger value,
// with `a` returned on ties.
template <typename T>
typename std::enable_if_t<std::is_floating_point_v<T>, T> max_value(
    T a,
    T b) {
  if (std::isnan(a)) {
    return a;
  }
  if (std::isnan(b)) {
    return b;
  }
  return (a < b) ? b : a;
}
146
// Maximum for non-floating-point types: returns `b` only when `a < b`,
// otherwise `a`.
template <typename T>
typename std::enable_if_t<!std::is_floating_point_v<T>, T> max_value(
    T a,
    T b) {
  if (a < b) {
    return b;
  }
  return a;
}
153
// Floating-point minimum that propagates NaN: if either operand is NaN that
// operand is returned (`a` takes precedence); otherwise the smaller value,
// with `b` returned on ties.
template <typename T>
typename std::enable_if_t<std::is_floating_point_v<T>, T> min_value(
    T a,
    T b) {
  if (std::isnan(a)) {
    return a;
  }
  if (std::isnan(b)) {
    return b;
  }
  return (a < b) ? a : b;
}
160
// Minimum for non-floating-point types: returns `a` only when `a < b`,
// otherwise `b`.
template <typename T>
typename std::enable_if_t<!std::is_floating_point_v<T>, T> min_value(
    T a,
    T b) {
  if (a < b) {
    return a;
  }
  return b;
}
167
168 template <typename T>
binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)169 InterpValue binary_op(
170 const InterpValue& lhs,
171 const InterpValue& rhs,
172 IRNodeType op_type) {
173 std::vector<T> lhs_v = lhs.as_vec<T>();
174 std::vector<T> rhs_v = rhs.as_vec<T>();
175 std::vector<T> result_v(lhs_v.size());
176 for (const auto i : c10::irange(lhs_v.size())) {
177 switch (op_type) {
178 case IRNodeType::kAdd:
179 result_v[i] = lhs_v[i] + rhs_v[i];
180 break;
181 case IRNodeType::kSub:
182 result_v[i] = lhs_v[i] - rhs_v[i];
183 break;
184 case IRNodeType::kMul:
185 result_v[i] = lhs_v[i] * rhs_v[i];
186 break;
187 case IRNodeType::kDiv:
188 result_v[i] = div_value(lhs_v[i], rhs_v[i]);
189 break;
190 case IRNodeType::kMod:
191 result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
192 break;
193 case IRNodeType::kMax:
194 result_v[i] = max_value(lhs_v[i], rhs_v[i]);
195 break;
196 case IRNodeType::kMin:
197 result_v[i] = min_value(lhs_v[i], rhs_v[i]);
198 break;
199 default:
200 // TODO: change to a proper error report
201 throw std::runtime_error("invalid operator type");
202 }
203 }
204 return InterpValue(result_v);
205 }
206
207 template <typename T>
bitwise_binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)208 InterpValue bitwise_binary_op(
209 const InterpValue& lhs,
210 const InterpValue& rhs,
211 IRNodeType op_type) {
212 std::vector<T> lhs_v = lhs.as_vec<T>();
213 std::vector<T> rhs_v = rhs.as_vec<T>();
214 std::vector<T> result_v(lhs_v.size());
215 for (const auto i : c10::irange(lhs_v.size())) {
216 switch (op_type) {
217 case IRNodeType::kAnd:
218 result_v[i] = lhs_v[i] & rhs_v[i];
219 break;
220 case IRNodeType::kOr:
221 result_v[i] = lhs_v[i] | rhs_v[i];
222 break;
223 case IRNodeType::kXor:
224 result_v[i] = lhs_v[i] ^ rhs_v[i];
225 break;
226 default:
227 // TODO: change to a proper error report
228 throw std::runtime_error("invalid operator type");
229 }
230 }
231 return InterpValue(result_v);
232 }
233
234 template <typename T>
shift_binary_op(const InterpValue & lhs,const InterpValue & rhs,IRNodeType op_type)235 InterpValue shift_binary_op(
236 const InterpValue& lhs,
237 const InterpValue& rhs,
238 IRNodeType op_type) {
239 std::vector<T> lhs_v = lhs.as_vec<T>();
240 std::vector<T> rhs_v = rhs.as_vec<T>();
241 std::vector<T> result_v(lhs_v.size());
242 for (const auto i : c10::irange(lhs_v.size())) {
243 switch (op_type) {
244 case IRNodeType::kLshift: {
245 auto a = static_cast<std::make_unsigned_t<T>>(lhs_v[i]);
246 result_v[i] = a << rhs_v[i];
247 break;
248 }
249 case IRNodeType::kRshift:
250 result_v[i] = lhs_v[i] >> rhs_v[i];
251 break;
252 default:
253 // TODO: change to a proper error report
254 throw std::runtime_error("invalid operator type");
255 }
256 }
257 return InterpValue(result_v);
258 }
259
260 template <typename T, typename R>
compare_select_op(const InterpValue & lhs,const InterpValue & rhs,const InterpValue & retval1,const InterpValue & retval2,CompareSelectOperation cmp_op)261 InterpValue compare_select_op(
262 const InterpValue& lhs,
263 const InterpValue& rhs,
264 const InterpValue& retval1,
265 const InterpValue& retval2,
266 CompareSelectOperation cmp_op) {
267 std::vector<T> lhs_v = lhs.as_vec<T>();
268 std::vector<T> rhs_v = rhs.as_vec<T>();
269 std::vector<R> ret_val1_v = retval1.as_vec<R>();
270 std::vector<R> ret_val2_v = retval2.as_vec<R>();
271 std::vector<R> result_v(lhs_v.size());
272 for (const auto i : c10::irange(lhs_v.size())) {
273 switch (cmp_op) {
274 case CompareSelectOperation::kEQ:
275 result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
276 break;
277 case CompareSelectOperation::kNE:
278 result_v[i] = (lhs_v[i] != rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
279 break;
280 case CompareSelectOperation::kGT:
281 result_v[i] = (lhs_v[i] > rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
282 break;
283 case CompareSelectOperation::kGE:
284 result_v[i] = (lhs_v[i] >= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
285 break;
286 case CompareSelectOperation::kLT:
287 result_v[i] = (lhs_v[i] < rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
288 break;
289 case CompareSelectOperation::kLE:
290 result_v[i] = (lhs_v[i] <= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
291 break;
292 default:
293 // TODO: change to a proper error report
294 throw std::runtime_error("invalid operator type");
295 }
296 }
297 return InterpValue(result_v);
298 }
299
300 template <
301 typename D,
302 std::enable_if_t<std::is_same_v<
303 decltype(detail::bin_op_deducer(std::declval<D>())),
304 void>>* = nullptr>
visit_binary_op(NodePtr<D> v,bool option=false)305 void visit_binary_op(NodePtr<D> v, bool option = false) {
306 v->lhs()->accept(this);
307 InterpValue lhs_v = value_;
308 v->rhs()->accept(this);
309 InterpValue rhs_v = value_;
310 if (lhs_v.dtype() != rhs_v.dtype()) {
311 throw malformed_input("bad dtype in binary op", v);
312 }
313
314 IRNodeType expr_type = v->expr_type();
315 if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kOr ||
316 expr_type == IRNodeType::kXor) {
317 switch (lhs_v.dtype().scalar_type()) {
318 #define TYPE_CASE(Type, Name) \
319 case ScalarType::Name: \
320 value_ = bitwise_binary_op<Type>(lhs_v, rhs_v, expr_type); \
321 break;
322 AT_FORALL_INT_TYPES(TYPE_CASE);
323 #undef TYPE_CASE
324 case ScalarType::Bool:
325 value_ = bitwise_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
326 break;
327 default:
328 throw unsupported_dtype();
329 }
330 return;
331 }
332
333 if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
334 switch (lhs_v.dtype().scalar_type()) {
335 #define TYPE_CASE(Type, Name) \
336 case ScalarType::Name: \
337 value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
338 break;
339 AT_FORALL_INT_TYPES(TYPE_CASE);
340 #undef TYPE_CASE
341 case ScalarType::Bool:
342 value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
343 break;
344 default:
345 throw unsupported_dtype();
346 }
347 return;
348 }
349
350 switch (lhs_v.dtype().scalar_type()) {
351 #define TYPE_CASE(Type, Name) \
352 case ScalarType::Name: \
353 value_ = binary_op<Type>(lhs_v, rhs_v, expr_type); \
354 break;
355 AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
356 #undef TYPE_CASE
357 case ScalarType::Bool:
358 value_ = binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
359 break;
360 default:
361 throw unsupported_dtype();
362 }
363 }
364
365 template <typename T>
compare_select_op_helper(const InterpValue & lhs,const InterpValue & rhs,const InterpValue & retval1,const InterpValue & retval2,CompareSelectOperation cmp_op)366 InterpValue compare_select_op_helper(
367 const InterpValue& lhs,
368 const InterpValue& rhs,
369 const InterpValue& retval1,
370 const InterpValue& retval2,
371 CompareSelectOperation cmp_op) {
372 InterpValue value;
373 switch (retval1.dtype().scalar_type()) {
374 #define TYPE_CASE(Type, Name) \
375 case ScalarType::Name: \
376 value = compare_select_op<T, Type>(lhs, rhs, retval1, retval2, cmp_op); \
377 break;
378 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
379 #undef TYPE_CASE
380 default:
381 throw unsupported_dtype();
382 }
383
384 return value;
385 }
386
visit_compare_select_op(const CompareSelectPtr & v,CompareSelectOperation cmp_op)387 void visit_compare_select_op(
388 const CompareSelectPtr& v,
389 CompareSelectOperation cmp_op) {
390 v->lhs()->accept(this);
391 InterpValue lhs_v = value_;
392 v->rhs()->accept(this);
393 InterpValue rhs_v = value_;
394 v->ret_val1()->accept(this);
395 InterpValue ret_val1_v = value_;
396 v->ret_val2()->accept(this);
397 InterpValue ret_val2_v = value_;
398
399 if (lhs_v.dtype() != rhs_v.dtype() ||
400 ret_val1_v.dtype() != ret_val2_v.dtype()) {
401 throw malformed_input("bad dtype in CompareSelect", v);
402 }
403
404 switch (lhs_v.dtype().scalar_type()) {
405 #define TYPE_CASE(Type, Name) \
406 case ScalarType::Name: \
407 value_ = compare_select_op_helper<Type>( \
408 lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \
409 break;
410 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
411 #undef TYPE_CASE
412 default:
413 throw unsupported_dtype();
414 }
415 }
416
417 #define IMM_VISIT(Type, Name) \
418 TORCH_API void visit(const Name##ImmPtr& v) override { \
419 value_ = InterpValue(v->value()); \
420 }
421 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
422 #undef IMM_VISIT
423
visit(const BlockPtr & v)424 TORCH_API void visit(const BlockPtr& v) override {
425 BlockPtr last = scope_;
426 scope_ = v;
427 for (const StmtPtr& s : v->stmts()) {
428 s->accept(this);
429 }
430
431 auto it = var_by_scope_.find(v);
432 if (it != var_by_scope_.end()) {
433 for (const ExprPtr& v : it->second) {
434 eval_context_.erase(v);
435 }
436 var_by_scope_.erase(it);
437 }
438
439 scope_ = last;
440 }
441
visit(const VarPtr & v)442 TORCH_API void visit(const VarPtr& v) override {
443 auto iter = eval_context_.find(v);
444 if (iter == eval_context_.end()) {
445 throw malformed_input("could not find Var in context", v);
446 }
447
448 value_ = iter->second;
449 }
450
451 // disable ubsan because sometimes this performs out-of-bound casts
452 // e.g. it will cast negative floats to unsigned char
453 template <typename SrcType, typename DstType>
castValues(const Dtype & src_dtype,const InterpValue & v)454 std::vector<DstType> castValues(const Dtype& src_dtype, const InterpValue& v)
455 __ubsan_ignore_undefined__ {
456 const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
457 std::vector<DstType> dst_values(src_values.size());
458 for (int i = 0; i < src_dtype.lanes(); ++i) {
459 // NOLINTNEXTLINE(bugprone-signed-char-misuse)
460 dst_values[i] = static_cast<DstType>(underlyingValue(src_values[i]));
461 }
462 return dst_values;
463 }
464
465 template <typename SrcType>
doCastFromSrc(const Dtype & src_dtype,const Dtype & dst_dtype,const InterpValue & v)466 void doCastFromSrc(
467 const Dtype& src_dtype,
468 const Dtype& dst_dtype,
469 const InterpValue& v) {
470 switch (dst_dtype.scalar_type()) {
471 #define DST_TYPE_CASE(Type, Name) \
472 case ScalarType::Name: \
473 this->value_ = InterpValue(castValues<SrcType, Type>(src_dtype, v)); \
474 break;
475 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
476 #undef DST_TYPE_CASE
477 #define DST_TYPE_CASE_QUANT(Type, Name, CppType) \
478 case ScalarType::Name: { \
479 std::vector<CppType> vec = castValues<SrcType, CppType>(dst_dtype, v); \
480 std::vector<Type> qvec; \
481 qvec.reserve(vec.size()); \
482 for (CppType u : vec) { \
483 qvec.emplace_back(u); \
484 } \
485 this->value_ = InterpValue(qvec); \
486 } break;
487 DST_TYPE_CASE_QUANT(c10::quint8, QUInt8, uint8_t)
488 DST_TYPE_CASE_QUANT(c10::qint8, QInt8, int8_t)
489 #undef DST_TYPE_CASE_QUANT
490 default:
491 throw unsupported_dtype();
492 }
493 }
494
visit(const CastPtr & v)495 TORCH_API void visit(const CastPtr& v) override {
496 ExprPtr src_value = v->src_value();
497 src_value->accept(this);
498 Dtype dst_dtype = v->dtype();
499 Dtype src_dtype = src_value->dtype();
500 if (src_dtype.lanes() != dst_dtype.lanes()) {
501 throw malformed_input("lane mismatch in Cast", v);
502 }
503
504 if (src_dtype != dst_dtype) {
505 switch (src_dtype.scalar_type()) {
506 #define SRC_TYPE_CASE(Type, Name) \
507 case ScalarType::Name: \
508 doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
509 break;
510 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
511 SRC_TYPE_CASE(c10::quint8, QUInt8);
512 SRC_TYPE_CASE(c10::qint8, QInt8);
513 #undef SRC_TYPE_CASE
514 default:
515 throw unsupported_dtype();
516 }
517 }
518 }
519
520 template <typename SrcType, typename DstType>
bitcastValues(const Dtype & src_dtype,const InterpValue & v)521 std::vector<DstType> bitcastValues(
522 const Dtype& src_dtype,
523 const InterpValue& v) {
524 const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
525 std::vector<DstType> dst_values(src_values.size());
526 for (int i = 0; i < src_dtype.lanes(); ++i) {
527 dst_values[i] = raw_bitcast<DstType>(src_values[i]);
528 }
529 return dst_values;
530 }
531
532 template <typename SrcType>
doBitCastFromSrc(const Dtype & src_dtype,const Dtype & dst_dtype,const InterpValue & v)533 void doBitCastFromSrc(
534 const Dtype& src_dtype,
535 const Dtype& dst_dtype,
536 const InterpValue& v) {
537 switch (dst_dtype.scalar_type()) {
538 #define DST_TYPE_CASE(Type, Name) \
539 case ScalarType::Name: \
540 this->value_ = InterpValue(bitcastValues<SrcType, Type>(src_dtype, v)); \
541 break;
542 // bool/half not supported
543 AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE);
544 #undef DST_TYPE_CASE
545 default:
546 throw unsupported_dtype();
547 }
548 }
549
visit(const BitCastPtr & v)550 TORCH_API void visit(const BitCastPtr& v) override {
551 ExprPtr src_value = v->src_value();
552 src_value->accept(this);
553 Dtype dst_dtype = v->dtype();
554 Dtype src_dtype = src_value->dtype();
555 if (src_dtype.byte_size() != dst_dtype.byte_size()) {
556 throw malformed_input("lane mismatch in Cast", v);
557 }
558 if (src_dtype != dst_dtype) {
559 switch (src_dtype.scalar_type()) {
560 #define SRC_TYPE_CASE(Type, Name) \
561 case ScalarType::Name: \
562 doBitCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
563 break;
564 // bool/half not supported
565 AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE);
566 #undef SRC_TYPE_CASE
567 default:
568 throw unsupported_dtype();
569 }
570 }
571 }
572
visit(const ForPtr & v)573 TORCH_API void visit(const ForPtr& v) override {
574 ExprPtr var_node = v->var();
575 v->start()->accept(this);
576 auto dtype = value_.dtype();
577 auto start = value_.intValue();
578 v->stop()->accept(this);
579 auto stop = value_.intValue();
580 if (eval_context_.count(var_node)) {
581 throw malformed_input("could not find var_node in For context", v);
582 }
583
584 for (auto i = start; i < stop; i++) {
585 eval_context_[var_node] = InterpValue(dtype, i);
586 if (v->body()) {
587 v->body()->accept(this);
588 }
589 }
590 eval_context_.erase(var_node);
591 }
592
visit(const RampPtr & v)593 TORCH_API void visit(const RampPtr& v) override {
594 v->base()->accept(this);
595 auto base = value().intValue();
596 v->stride()->accept(this);
597 auto stride = value().intValue();
598 int lanes = v->lanes();
599
600 std::vector<int64_t> values(lanes);
601 for (const auto i : c10::irange(lanes)) {
602 values[i] = base + i * stride;
603 }
604
605 value_ = InterpValue(values);
606 }
607
visit(const BroadcastPtr & v)608 TORCH_API void visit(const BroadcastPtr& v) override {
609 v->value()->accept(this);
610 InterpValue value = this->value();
611 int lanes = v->lanes();
612 switch (value.dtype().scalar_type()) {
613 #define TYPE_CASE(Type, Name) \
614 case ScalarType::Name: { \
615 std::vector<Type> v(lanes, value.as<Type>()); \
616 value_ = InterpValue(v); \
617 } break;
618 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
619 #undef TYPE_CASE
620 default:
621 throw unsupported_dtype();
622 }
623 }
624
visit(const IfThenElsePtr & v)625 TORCH_API void visit(const IfThenElsePtr& v) override {
626 v->condition()->accept(this);
627 bool cond_v = false;
628 switch (value_.dtype().scalar_type()) {
629 #define TYPE_CASE(Type, Name) \
630 case ScalarType::Name: { \
631 cond_v = value_.as<Type>(); \
632 } break;
633 AT_FORALL_SCALAR_TYPES_AND(Bool, TYPE_CASE);
634 #undef TYPE_CASE
635 case ScalarType::Half:
636 throw unsupported_dtype("IfThenElse condition can't have Half dtype");
637 case ScalarType::BFloat16:
638 throw unsupported_dtype(
639 "IfThenElse condition can't have BFloat16 dtype");
640 default:
641 throw unsupported_dtype();
642 }
643
644 if (cond_v) {
645 v->true_value()->accept(this);
646 } else {
647 v->false_value()->accept(this);
648 }
649 }
650
// Copies the elements of any iterable range into a std::vector<int64_t>,
// converting each element to int64_t.
template <typename T>
std::vector<int64_t> toLongVec(T&& t) {
  auto first = std::begin(t);
  auto last = std::end(t);
  return std::vector<int64_t>(first, last);
}
655
indexVec(const InterpValue & v)656 std::vector<int64_t> indexVec(const InterpValue& v) {
657 switch (v.dtype().scalar_type()) {
658 #define TYPE_CASE(Type, Name) \
659 case ScalarType::Name: \
660 return toLongVec(v.as_vec<Type>());
661 AT_FORALL_INT_TYPES(TYPE_CASE);
662 #undef TYPE_CASE
663 default:
664 throw unsupported_dtype();
665 }
666 return {};
667 }
668
check_bounds_throw(int64_t idx,int64_t bound,const BufPtr & buf)669 void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) {
670 std::stringstream ss;
671 ss << "Index out of bounds in check_bounds. Index: " << idx
672 << "; bounds: [0, " << bound << ").";
673 throw malformed_input(ss.str(), buf);
674 }
675
check_bounds(const BufPtr & buf,const std::vector<ExprPtr> & indices)676 void check_bounds(const BufPtr& buf, const std::vector<ExprPtr>& indices) {
677 const std::vector<ExprPtr>& dims = buf->dims();
678 if (dims.size() != indices.size()) {
679 // indices are flattened, but not buffer
680 if (indices.size() == 1) {
681 if (dims.size() != buf->strides().size()) {
682 throw malformed_input(
683 "Number of dimensions did not match number of strides", buf);
684 }
685 int64_t buf_size = 1;
686 if (!dims.empty()) {
687 ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
688 ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
689 for (const auto& i : c10::irange(dims.size())) {
690 buf_size_expr = buf_size_expr +
691 ((negative_one + ExprHandle(dims[i])) *
692 ExprHandle(buf->strides()[i]));
693 }
694 buf_size_expr.node()->accept(this);
695 buf_size = value().intValue();
696 }
697 indices[0]->accept(this);
698 const auto& index_values = indexVec(value());
699 for (auto& j : index_values) {
700 if (j < 0 || j >= buf_size) {
701 check_bounds_throw(j, buf_size, buf);
702 }
703 }
704 return;
705 }
706 throw malformed_input(
707 "dimensions and indices mismatch in check_bounds. Buf has " +
708 std::to_string(dims.size()) + " dimensions and indices has " +
709 std::to_string(indices.size()) + " dimensions.",
710 buf);
711 }
712 for (const auto& i : c10::irange(dims.size())) {
713 auto opt_dim = intValue(dims[i]);
714 if (!opt_dim) {
715 continue;
716 }
717 auto dim_bound = *opt_dim;
718 indices[i]->accept(this);
719 const auto& ithDimIndices = indexVec(value());
720 for (auto& j : ithDimIndices) {
721 if (j < 0 || j >= dim_bound) {
722 check_bounds_throw(j, dim_bound, buf);
723 }
724 }
725 }
726 }
727
visit(const LoadPtr & v)728 TORCH_API void visit(const LoadPtr& v) override {
729 auto iter = buffer_mapping_.find(v->buf());
730 if (iter == buffer_mapping_.end()) {
731 throw malformed_input("could not find base node in Load", v);
732 }
733 void* ptr = iter->second;
734
735 check_bounds(v->buf(), v->indices());
736
737 ExprPtr flat_idx =
738 flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
739 flat_idx->accept(this);
740 auto index = indexVec(value());
741 ScalarType v_sdtype = v->dtype().scalar_type();
742 switch (v_sdtype) {
743 #define TYPE_CASE(Type, Name) \
744 case ScalarType::Name: { \
745 Type* ptr##Name = static_cast<Type*>(ptr); \
746 std::vector<Type> val(index.size()); \
747 for (const auto i : c10::irange(index.size())) { \
748 val[i] = ptr##Name[index[i]]; \
749 GRAPH_DEBUG( \
750 "LOAD: ptr=", \
751 ptr##Name, \
752 ", buf=", \
753 v->buf()->name_hint(), \
754 ", idx=", \
755 index[i], \
756 ", val=", \
757 (int)underlyingValue(val[i])); \
758 } \
759 value_ = InterpValue(val); \
760 } break;
761 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
762 TYPE_CASE(c10::quint8, QUInt8);
763 TYPE_CASE(c10::qint8, QInt8);
764 #undef TYPE_CASE
765 default:
766 throw unsupported_dtype("scalar type:" + std::to_string(v_sdtype));
767 }
768 }
769
visit(const StorePtr & v)770 TORCH_API void visit(const StorePtr& v) override {
771 auto iter = buffer_mapping_.find(v->buf());
772 if (iter == buffer_mapping_.end()) {
773 throw malformed_input("could not find base node in Store", v);
774 }
775
776 void* ptr = iter->second;
777
778 check_bounds(v->buf(), v->indices());
779
780 ExprPtr flat_idx =
781 flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
782 flat_idx->accept(this);
783 auto index = indexVec(value());
784 ScalarType v_sdtype = v->value()->dtype().scalar_type();
785
786 switch (v_sdtype) {
787 #define TYPE_CASE(Type, Name) \
788 case ScalarType::Name: { \
789 v->value()->accept(this); \
790 std::vector<Type> value = this->value().as_vec<Type>(); \
791 if (index.size() != value.size()) { \
792 throw malformed_input("value size mismatch in Store", v); \
793 } \
794 Type* ptr##Name = static_cast<Type*>(ptr); \
795 for (const auto i : c10::irange(index.size())) { \
796 GRAPH_DEBUG( \
797 "STORE: ptr=", \
798 ptr##Name, \
799 ", buf=", \
800 v->buf()->name_hint(), \
801 ", idx=", \
802 index[i], \
803 ", val=", \
804 (int)underlyingValue(value[i])); \
805 ptr##Name[index[i]] = value[i]; \
806 } \
807 } break;
808 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
809 TYPE_CASE(c10::quint8, QUInt8);
810 TYPE_CASE(c10::qint8, QInt8);
811 #undef TYPE_CASE
812 default:
813 throw unsupported_dtype();
814 }
815 }
816
visit(const ExternalCallPtr & v)817 void visit(const ExternalCallPtr& v) override {
818 auto& func_registry = getNNCFunctionRegistry();
819 if (!func_registry.count(v->func_name())) {
820 throw unimplemented_lowering(v);
821 }
822 GRAPH_DEBUG(
823 "EXTERNAL CALL: func=",
824 v->func_name(),
825 ", buf=",
826 v->buf()->name_hint());
827
828 std::vector<BufPtr> bufs(v->buf_args());
829 bufs.insert(bufs.begin(), v->buf());
830
831 std::vector<void*> buf_ptrs;
832 std::vector<int64_t> buf_ranks;
833 std::vector<int64_t> buf_dims;
834 std::vector<int64_t> buf_strides;
835 std::vector<int8_t> buf_dtypes;
836 std::vector<int64_t> extra_args;
837
838 for (const BufPtr& b : bufs) {
839 auto iter = buffer_mapping_.find(b);
840 if (iter == buffer_mapping_.end()) {
841 throw malformed_input("could not find buf", v);
842 }
843
844 buf_ptrs.push_back(iter->second);
845 buf_ranks.push_back(b->dims().size());
846 buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
847 for (const ExprPtr& dim_expr : b->dims()) {
848 dim_expr->accept(this);
849 buf_dims.push_back(value().intValue());
850 }
851 for (const ExprPtr& stride_expr : b->strides()) {
852 stride_expr->accept(this);
853 buf_strides.push_back(value().intValue());
854 }
855 }
856 for (const ExprPtr& a : v->args()) {
857 a->accept(this);
858 int64_t val = 0;
859 if (value().dtype() == kLong) {
860 val = value().as<int64_t>();
861 } else if (value().dtype() == kInt) {
862 val = value().intValue();
863 } else if (value().dtype() == kDouble) {
864 auto x = value().as<double>();
865 val = reinterpret_cast<int64_t*>(&x)[0];
866 } else if (value().dtype() == kFloat) {
867 auto x = value().as<float>();
868 val = reinterpret_cast<int64_t*>(&x)[0];
869 } else {
870 throw malformed_input(
871 "extra_args in ExternalCalls must have int64 dtype", v);
872 }
873 extra_args.push_back(val);
874 }
875
876 auto fn_ptr = func_registry.at(v->func_name());
877 (*fn_ptr)(
878 bufs.size(),
879 buf_ptrs.data(),
880 buf_ranks.data(),
881 buf_dims.data(),
882 buf_strides.data(),
883 buf_dtypes.data(),
884 extra_args.size(),
885 extra_args.data());
886 }
887
visit(const ExternalCallWithAllocPtr & v)888 void visit(const ExternalCallWithAllocPtr& v) override {
889 auto& func_registry = getNNCFunctionRegistry();
890 if (!func_registry.count(v->func_name())) {
891 throw unimplemented_lowering(v);
892 }
893 GRAPH_DEBUG("EXTERNAL CALL: func=", v->func_name());
894
895 const auto& bufs_out = v->buf_out_args();
896 const auto& bufs_in = v->buf_args();
897 const auto bufs_in_size = bufs_in.size();
898 const auto bufs_out_size = bufs_out.size();
899
900 std::vector<void*> buf_ptrs(bufs_in_size + 2 * bufs_out_size);
901 std::vector<int64_t> buf_ranks;
902 std::vector<int64_t> buf_dims;
903 std::vector<int64_t> buf_strides;
904 std::vector<int8_t> buf_dtypes;
905 std::vector<int64_t> extra_args;
906
907 size_t i = 0;
908 for (const auto& b : bufs_in) {
909 auto iter = buffer_mapping_.find(b);
910 if (iter == buffer_mapping_.end()) {
911 throw malformed_input("could not find buf", v);
912 }
913 buf_ptrs[bufs_out_size + i] = iter->second;
914 // @lint-ignore CLANGTIDY
915 buf_ranks.push_back(b->dims().size());
916 buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
917 for (const auto& dim_expr : b->dims()) {
918 dim_expr->accept(this);
919 buf_dims.push_back(value().intValue());
920 }
921 for (const ExprPtr& stride_expr : b->strides()) {
922 stride_expr->accept(this);
923 buf_strides.push_back(value().intValue());
924 }
925 i++;
926 }
927 for (const auto& a : v->args()) {
928 a->accept(this);
929 int64_t val = 0;
930 if (value().dtype() == kLong) {
931 val = value().as<int64_t>();
932 } else if (value().dtype() == kInt) {
933 val = value().intValue();
934 } else if (value().dtype() == kDouble) {
935 auto x = value().as<double>();
936 val = reinterpret_cast<int64_t*>(&x)[0];
937 } else if (value().dtype() == kFloat) {
938 auto x = value().as<float>();
939 val = reinterpret_cast<int64_t*>(&x)[0];
940 } else {
941 throw malformed_input(
942 "extra_args in ExternalCalls must have int64 dtype", v);
943 }
944 extra_args.push_back(val);
945 }
946
947 auto fn_ptr = func_registry.at(v->func_name());
948 (*fn_ptr)(
949 bufs_in_size,
950 buf_ptrs.data(),
951 buf_ranks.data(),
952 buf_dims.data(),
953 buf_strides.data(),
954 buf_dtypes.data(),
955 extra_args.size(),
956 extra_args.data());
957
958 for (i = 0; i < bufs_out_size; ++i) {
959 const auto& buf_out = bufs_out[i];
960 buffer_mapping_[buf_out] = buf_ptrs[i];
961 ext_bufs_free_ptr_[buf_out] = buf_ptrs[bufs_in_size + bufs_out_size + i];
962 }
963 }
964
965 template <typename TReturn, typename TInput>
visit_intrinsics_helper(const IntrinsicsPtr & v)966 void visit_intrinsics_helper(const IntrinsicsPtr& v) {
967 std::vector<InterpValue> values(v->nparams());
968 for (const auto i : c10::irange(v->nparams())) {
969 v->param(i)->accept(this);
970 values[i] = this->value();
971 }
972 std::vector<TInput> v1;
973 if (!values.empty()) {
974 v1 = values[0].as_vec<TInput>();
975 }
976 std::vector<TInput> v2;
977 if (values.size() >= 2ULL) {
978 v2 = values[1].as_vec<TInput>();
979 if (v1.size() != v2.size()) {
980 throw malformed_input("value size mismatch in Intrinsics", v);
981 }
982 }
983
984 if (values.size() > 2) {
985 throw unimplemented_lowering(v);
986 }
987
988 std::vector<TReturn> result(v1.size(), -1);
989 if (values.size() == 1ULL) {
990 for (const auto i : c10::irange(v1.size())) {
991 result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i]);
992 }
993 } else {
994 for (const auto i : c10::irange(v1.size())) {
995 result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i], v2[i]);
996 }
997 }
998 value_ = InterpValue(result);
999 }
1000
visit(const IntrinsicsPtr & v)1001 TORCH_API void visit(const IntrinsicsPtr& v) override {
1002 auto ty = v->dtype().scalar_type();
1003 if (v->op_type() == kIsNan) {
1004 auto inp_dtype = v->params().at(0)->dtype().scalar_type();
1005 if (inp_dtype == ScalarType::Float) {
1006 visit_intrinsics_helper<int, float>(v);
1007 } else if (inp_dtype == ScalarType::Double) {
1008 visit_intrinsics_helper<int, double>(v);
1009 } else if (inp_dtype == ScalarType::Half) {
1010 throw unsupported_dtype(); // TODO
1011 } else if (inp_dtype == ScalarType::BFloat16) {
1012 throw unsupported_dtype(); // TODO
1013 }
1014 } else {
1015 switch (ty) {
1016 #define TYPE_CASE(Type, Name) \
1017 case ScalarType::Name: \
1018 visit_intrinsics_helper<Type, Type>(v); \
1019 break;
1020 AT_FORALL_SCALAR_TYPES(TYPE_CASE);
1021 #undef TYPE_CASE
1022 default:
1023 throw unsupported_dtype();
1024 }
1025 }
1026 }
1027
visit(const AllocatePtr & v)1028 void visit(const AllocatePtr& v) override {
1029 BufPtr b = v->buf();
1030 std::vector<ExprPtr> dims = b->dims();
1031 int64_t total_byte_size = b->dtype().byte_size();
1032 for (auto& dim : dims) {
1033 dim->accept(this);
1034 total_byte_size *= value_.intValue();
1035 }
1036 auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
1037 GRAPH_DEBUG(
1038 "ALLOCATE: buf=", v->buf()->name_hint(), ", size=", total_byte_size);
1039 auto buffer = std::make_unique<std::vector<int>>(int_count);
1040 auto iter = buffer_mapping_.find(b);
1041 if (iter != buffer_mapping_.end() && iter->second != nullptr) {
1042 throw std::runtime_error(
1043 "Allocate a buffer that has already been allocated: " +
1044 v->buffer_var()->name_hint());
1045 }
1046 buffer_mapping_[b] = buffer->data();
1047 internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
1048 }
1049
visit(const PlacementAllocatePtr & v)1050 void visit(const PlacementAllocatePtr& v) override {
1051 buffer_mapping_[v->buf()] = buffer_mapping_.at(v->buf_to_reuse());
1052 }
1053
visit(const FreePtr & v)1054 void visit(const FreePtr& v) override {
1055 BufPtr b = v->buf();
1056 GRAPH_DEBUG("FREE: buf=", v->buf()->name_hint());
1057 auto count = internal_buffers_.erase(b);
1058 if (count == 0) {
1059 throw std::runtime_error(
1060 "Free a buffer that is not currently bound: " +
1061 v->buffer_var()->name_hint());
1062 }
1063 buffer_mapping_.erase(b);
1064 }
1065
visit(const FreeExtPtr & v)1066 void visit(const FreeExtPtr& v) override {
1067 const auto& bufs = v->bufs();
1068 const auto bufs_num = bufs.size();
1069 std::vector<void*> buf_ptrs;
1070 for (const auto& buf : bufs) {
1071 if (!ext_bufs_free_ptr_.count(buf)) {
1072 throw std::runtime_error(
1073 "Free an external allocated buffer that does not have corresponding pointer for freeing: " +
1074 buf->base_handle()->name_hint());
1075 }
1076 buf_ptrs.push_back(ext_bufs_free_ptr_[buf]);
1077 }
1078 nnc_aten_free(bufs_num, buf_ptrs.data());
1079 }
1080
visit(const LetPtr & v)1081 void visit(const LetPtr& v) override {
1082 var_by_scope_[scope_].push_back(v->var());
1083 bindVar(v->var(), evaluateExpr(v->value()));
1084 }
1085
visit(const CondPtr & v)1086 void visit(const CondPtr& v) override {
1087 v->condition()->accept(this);
1088 if (value().intValue()) {
1089 if (v->true_stmt()) {
1090 v->true_stmt()->accept(this);
1091 }
1092 } else {
1093 if (v->false_stmt()) {
1094 v->false_stmt()->accept(this);
1095 }
1096 }
1097 }
1098
1099 private:
1100 template <
1101 typename TReturn,
1102 typename TInput,
1103 std::enable_if_t<std::is_floating_point_v<TInput>, int> = 0>
compute_intrinsics(IntrinsicsOp op_type,TInput v)1104 static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) {
1105 switch (op_type) {
1106 case kSin:
1107 return std::sin(v);
1108 case kCos:
1109 return std::cos(v);
1110 case kTan:
1111 return std::tan(v);
1112 case kAsin:
1113 return std::asin(v);
1114 case kAcos:
1115 return std::acos(v);
1116 case kAtan:
1117 return std::atan(v);
1118 case kSinh:
1119 return std::sinh(v);
1120 case kCosh:
1121 return std::cosh(v);
1122 case kTanh:
1123 return std::tanh(v);
1124 case kExp:
1125 return std::exp(v);
1126 case kAbs:
1127 return std::abs(v);
1128 case kExpm1:
1129 return std::expm1(v);
1130 case kLog:
1131 return std::log(v);
1132 case kLog2:
1133 return std::log2(v);
1134 case kLog10:
1135 return std::log10(v);
1136 case kLog1p:
1137 return std::log1p(v);
1138 case kErf:
1139 return std::erf(v);
1140 case kErfc:
1141 return std::erfc(v);
1142 case kSqrt:
1143 return std::sqrt(v);
1144 case kRsqrt: {
1145 auto rsqrt = [](TInput v) __ubsan_ignore_float_divide_by_zero__ {
1146 return 1.0f / std::sqrt(v);
1147 };
1148 return rsqrt(v);
1149 }
1150 case kCeil:
1151 return std::ceil(v);
1152 case kFloor:
1153 return std::floor(v);
1154 case kRound:
1155 return std::round(v);
1156 case kTrunc:
1157 return std::trunc(v);
1158 case kLgamma:
1159 return std::lgamma(v);
1160 case kFrac:
1161 TInput intpart;
1162 return std::modf(v, &intpart);
1163 case kIsNan:
1164 return std::isnan(v);
1165 default:
1166 throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1167 }
1168 }
1169
1170 template <
1171 typename TReturn,
1172 typename TInput,
1173 std::enable_if_t<std::is_integral_v<TInput>, int> = 0>
compute_intrinsics(IntrinsicsOp op_type,TInput v)1174 static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) {
1175 switch (op_type) {
1176 case kAbs: {
1177 // internal tool complains about calling `abs` on unsigned, the
1178 // following makes the tool happy
1179 using X = std::conditional_t<std::is_unsigned_v<TInput>, int, TInput>;
1180 return std::is_unsigned_v<TInput> ? v : std::abs(static_cast<X>(v));
1181 }
1182 default:
1183 throw std::runtime_error(
1184 "Invalid integral op_type: " + std::to_string(op_type));
1185 }
1186 }
1187
1188 // specialization for float -> int ops (just kIsNan currently)
compute_intrinsics(IntrinsicsOp op_type,float v)1189 int compute_intrinsics(IntrinsicsOp op_type, float v) {
1190 switch (op_type) {
1191 case kIsNan:
1192 return std::isnan(v);
1193 default:
1194 throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1195 }
1196 }
1197
1198 template <typename TReturn, typename TInput>
compute_intrinsics(IntrinsicsOp op_type,TInput v1,TInput v2)1199 TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v1, TInput v2) {
1200 switch (op_type) {
1201 case kPow:
1202 return std::pow(v1, v2);
1203 case kFmod:
1204 return std::fmod(v1, v2);
1205 case kRemainder:
1206 return std::remainder(v1, v2);
1207 case kAtan2:
1208 return std::atan2(v1, v2);
1209 default:
1210 throw std::runtime_error("Invalid op_type: " + std::to_string(op_type));
1211 }
1212 }
1213
// Result of the most recently evaluated expression; visitors write it and
// callers read it back right after `accept`.
InterpValue value_;
// Key under which Let-bound variables are recorded in var_by_scope_.
// NOTE(review): presumably the block currently being executed — confirm.
BlockPtr scope_;
// NOTE(review): not used in this part of the file; presumably the values
// bound to variables via bindVar — confirm against the earlier definitions.
std::unordered_map<ExprPtr, InterpValue> eval_context_;
// Variables introduced by Let statements, grouped by the scope_ that was
// active when they were visited.
std::unordered_map<BlockPtr, std::vector<ExprPtr>> var_by_scope_;
// Raw storage pointer currently associated with each buffer.
std::unordered_map<BufPtr, void*> buffer_mapping_;
// Owns the storage created by Allocate until the matching Free erases it.
std::unordered_map<BufPtr, std::unique_ptr<std::vector<int>>>
    internal_buffers_;
// For buffers produced by external calls: the pointer that FreeExt passes to
// nnc_aten_free.
std::unordered_map<BufPtr, void*> ext_bufs_free_ptr_;
1222 };
1223
SimpleIREvaluator(StmtPtr stmt,const std::vector<BufferArg> & buffer_args,at::Device device,const std::string & kernel_func_name)1224 SimpleIREvaluator::SimpleIREvaluator(
1225 StmtPtr stmt,
1226 const std::vector<BufferArg>& buffer_args,
1227 at::Device device,
1228 const std::string& kernel_func_name)
1229 : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
1230 impl_ = std::make_unique<SimpleIREvaluatorImpl>();
1231 expand_intrinsics();
1232 }
1233
1234 SimpleIREvaluator::~SimpleIREvaluator() = default;
1235
call(const std::vector<CallArg> & args)1236 void SimpleIREvaluator::call(const std::vector<CallArg>& args) {
1237 std::vector<void*> raw_args(args.size());
1238 for (size_t i = 0; i < args.size(); i++) {
1239 auto const& bufferArg = buffer_args()[i];
1240 auto const& callArg = args[i];
1241 raw_args[i] = argToPtr(bufferArg, callArg);
1242 }
1243 call_raw(raw_args);
1244 }
1245
call_raw(const std::vector<void * > & args)1246 void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
1247 if (args.size() != buffer_args().size()) {
1248 throw malformed_input("bad args in IREvaluator call");
1249 }
1250 for (const auto i : c10::irange(args.size())) {
1251 bindArg(buffer_args()[i], args[i]);
1252 }
1253 stmt()->accept(&*impl_);
1254 impl_->clear();
1255 }
1256
bindArg(const BufferArg & bufArg,void * data)1257 void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
1258 if (!bufArg.isVar()) {
1259 impl_->bindBuf(bufArg.buf(), data);
1260 return;
1261 }
1262
1263 switch (bufArg.dtype().scalar_type()) {
1264 #define TYPE_CASE(Type, Name) \
1265 case ScalarType::Name: { \
1266 Type typed_data; \
1267 memcpy(&typed_data, data, sizeof(Type)); \
1268 impl_->bindVar(bufArg.var(), typed_data); \
1269 break; \
1270 }
1271 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
1272 #undef TYPE_CASE
1273 default:
1274 throw unsupported_dtype();
1275 }
1276 }
1277
bindVar(const VarPtr & v,const ExprPtr & e)1278 void SimpleIREvaluator::bindVar(const VarPtr& v, const ExprPtr& e) {
1279 impl_->bindVar(v, impl_->evaluateExpr(e));
1280 }
1281
value() const1282 InterpValue SimpleIREvaluator::value() const {
1283 return impl_->value();
1284 }
1285
evalInt(ExprPtr e)1286 std::optional<int64_t> evalInt(ExprPtr e) {
1287 try {
1288 return ExprEval<SimpleIREvaluator>(cast<int64_t>(ExprHandle(std::move(e))))
1289 .value<int64_t>();
1290 } catch (std::runtime_error& err) {
1291 return std::nullopt;
1292 }
1293 }
1294
1295 } // namespace torch::jit::tensorexpr
1296