1 #pragma once
2
3 #include <string>
4 #include <utility>
5 #include <vector>
6
7 #include <torch/csrc/jit/tensorexpr/exceptions.h>
8 #include <torch/csrc/jit/tensorexpr/expr.h>
9 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
10 #include <torch/csrc/jit/tensorexpr/stmt.h>
11
12 #include <ATen/core/ivalue.h>
13
14 namespace torch {
15 namespace jit {
16 namespace tensorexpr {
17
// Comparison operator applied by a CompareSelect node to its lhs and rhs.
enum CompareSelectOperation {
  kEQ = 0, // equal
  kGT, // greater than
  kGE, // greater than or equal
  kLT, // less than
  kLE, // less than or equal
  kNE, // not equal
};
26
// Branch-prediction hint attached to a CompareSelect: whether the comparison
// is expected to usually be true (kLikely), usually false (kUnlikely), or has
// no known bias (kUnbiased).
enum CompareSelectBias {
  kUnbiased,
  kLikely,
  kUnlikely,
};
32
getPrecedence(IRNodeType ty)33 inline int getPrecedence(IRNodeType ty) {
34 // Match C++ operator precedence rules, since some pretty-print expressions to
35 // C++. SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
36 switch (ty) {
37 case kPrimitive:
38 return 0;
39 case kCast:
40 case kBitCast:
41 return 2;
42 case kAdd:
43 case kSub:
44 return 6;
45 case kMul:
46 case kDiv:
47 case kMod:
48 return 5;
49 case kMax:
50 case kMin:
51 return 99;
52 case kAnd:
53 return 11;
54 case kOr:
55 return 13;
56 case kLshift:
57 case kRshift:
58 return 7;
59 case kXor:
60 return 12;
61 case kCompareSelect:
62 return 16;
63 default:
64 return 99;
65 }
66 }
67
68 class TORCH_API Cast : public ExprNode<Cast> {
69 public:
src_value()70 ExprPtr src_value() const {
71 return src_value_;
72 }
73
set_src_value(ExprPtr src_value)74 void set_src_value(ExprPtr src_value) {
75 src_value_ = std::move(src_value);
76 }
77
make(Dtype dtype,const ExprHandle & src_value)78 static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
79 return ExprHandle(alloc<Cast>(dtype, src_value.node()));
80 }
Cast(Dtype dtype,ExprPtr src_value)81 Cast(Dtype dtype, ExprPtr src_value)
82 : ExprNodeBase(dtype, kCast), src_value_(std::move(src_value)) {}
83
isConstant()84 bool isConstant() const override {
85 return src_value_->isConstant();
86 }
87
88 private:
89 ExprPtr src_value_;
90 };
91
92 template <typename T>
cast(const ExprHandle & src_value)93 ExprHandle cast(const ExprHandle& src_value) {
94 return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
95 }
96
97 // This is a bitwise cast, akin to bitcast in LLVM
98 class TORCH_API BitCast : public ExprNode<BitCast> {
99 public:
src_value()100 ExprPtr src_value() const {
101 return src_value_;
102 }
103
set_src_value(ExprPtr src_value)104 void set_src_value(ExprPtr src_value) {
105 src_value_ = std::move(src_value);
106 }
107
make(Dtype dtype,const ExprHandle & src_value)108 static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
109 return ExprHandle(alloc<BitCast>(dtype, src_value.node()));
110 }
BitCast(Dtype dtype,ExprPtr src_value)111 BitCast(Dtype dtype, ExprPtr src_value)
112 : ExprNodeBase(dtype, kBitCast), src_value_(std::move(src_value)) {
113 TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
114 }
115
isConstant()116 bool isConstant() const override {
117 return src_value_->isConstant();
118 }
119
120 private:
121 ExprPtr src_value_;
122 };
123
124 template <typename T>
bitcast(const ExprHandle & src_value)125 ExprHandle bitcast(const ExprHandle& src_value) {
126 return BitCast::make(
127 Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
128 }
129
130 // Represent the expression node for binary operators.
131 // A CRTP pattern to share common code among the operators.
132 template <typename Op>
133 class BinaryOpNode : public ExprNode<Op> {
134 public:
lhs()135 ExprPtr lhs() const {
136 return this->lhs_;
137 }
rhs()138 ExprPtr rhs() const {
139 return this->rhs_;
140 }
141
set_lhs(ExprPtr lhs)142 void set_lhs(ExprPtr lhs) {
143 lhs_ = std::move(lhs);
144 }
145
set_rhs(ExprPtr rhs)146 void set_rhs(ExprPtr rhs) {
147 rhs_ = std::move(rhs);
148 }
149
make(const ExprHandle & lhs,const ExprHandle & rhs)150 static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
151 return ExprHandle(alloc<Op>(lhs.node(), rhs.node()));
152 }
153
154 BinaryOpNode(
155 ExprPtr lhs_v,
156 ExprPtr rhs_v,
157 IRNodeType expr_type,
158 ScalarType ret_type = ScalarType::Undefined)
159 : ExprNode<Op>(
160 BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type),
161 expr_type),
162 lhs_(CastIfNeeded(std::move(lhs_v), ExprNode<Op>::dtype())),
163 rhs_(CastIfNeeded(std::move(rhs_v), ExprNode<Op>::dtype())) {}
164
165 private:
CastIfNeeded(ExprPtr expr,Dtype dst_dtype)166 static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) {
167 if (expr->dtype() == dst_dtype) {
168 return expr;
169 }
170 return Cast::make(dst_dtype, ExprHandle(std::move(expr))).node();
171 }
172
173 ExprPtr lhs_;
174 ExprPtr rhs_;
175 };
176
namespace detail {
// Overload-resolution trick to detect whether a type derives from
// BinaryOpNode<T> for some T: the first overload is viable (and deduces T)
// only for such types; everything else falls through to the variadic
// overload. Neither function is ever defined - they are only used in
// unevaluated contexts (e.g. decltype).
template <typename T>
void bin_op_deducer(BinaryOpNode<T>);
bool bin_op_deducer(...);
} // namespace detail
182
// Addition: lhs + rhs.
class TORCH_API Add : public BinaryOpNode<Add> {
 public:
  Add(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAdd) {}
};
188
// Subtraction: lhs - rhs.
class TORCH_API Sub : public BinaryOpNode<Sub> {
 public:
  Sub(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kSub) {}
};
194
// Multiplication: lhs * rhs.
class TORCH_API Mul : public BinaryOpNode<Mul> {
 public:
  Mul(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMul) {}
};
200
// Division: lhs / rhs.
class TORCH_API Div : public BinaryOpNode<Div> {
 public:
  Div(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kDiv) {}
};
206
// Modulo / remainder: lhs % rhs.
class TORCH_API Mod : public BinaryOpNode<Mod> {
 public:
  Mod(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMod) {}
};
212
213 template <typename Op>
214 class BitwiseOpNode : public BinaryOpNode<Op> {
215 public:
BitwiseOpNode(ExprPtr lhs,ExprPtr rhs,IRNodeType type)216 BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type)
217 : BinaryOpNode<Op>(std::move(lhs), std::move(rhs), type) {}
218
make(const ExprHandle & lhs,const ExprHandle & rhs)219 static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
220 if (!lhs.dtype().is_integral()) {
221 throw unsupported_dtype();
222 }
223 if (lhs.dtype() != rhs.dtype()) {
224 throw malformed_input("lhs/rhs dtype mismatch");
225 }
226 return BinaryOpNode<Op>::make(lhs, rhs);
227 }
228 };
229
// Bitwise AND: lhs & rhs.
class TORCH_API And : public BitwiseOpNode<And> {
 public:
  And(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAnd) {}
};
235
// Bitwise OR: lhs | rhs.
class TORCH_API Or : public BitwiseOpNode<Or> {
 public:
  Or(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOr) {}
};
241
// Bitwise XOR: lhs ^ rhs.
class TORCH_API Xor : public BitwiseOpNode<Xor> {
 public:
  Xor(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kXor) {}
};
247
// Left shift: lhs << rhs.
class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
 public:
  Lshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kLshift) {}
};
253
// Right shift: lhs >> rhs.
class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
 public:
  Rshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kRshift) {}
};
259
260 // TODO: add TORCH_API
261 // Currently adding it results in a compilation error on Windows
262 class Max : public BinaryOpNode<Max> {
263 private:
264 bool propagate_nans_;
265
266 public:
Max(ExprPtr lhs,ExprPtr rhs,bool propagate_nans)267 Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
268 : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMax),
269 propagate_nans_(propagate_nans) {}
270
propagate_nans()271 bool propagate_nans() const {
272 return propagate_nans_;
273 }
274
275 static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
make(const ExprHandle & lhs,const ExprHandle & rhs,bool propagate_nans)276 static ExprHandle make(
277 const ExprHandle& lhs,
278 const ExprHandle& rhs,
279 bool propagate_nans) {
280 return ExprHandle(alloc<Max>(lhs.node(), rhs.node(), propagate_nans));
281 }
282 };
283
284 // TODO: add TORCH_API
285 // Currently adding it results in a compilation error on Windows
286 class Min : public BinaryOpNode<Min> {
287 private:
288 bool propagate_nans_;
289
290 public:
Min(ExprPtr lhs,ExprPtr rhs,bool propagate_nans)291 Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
292 : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMin),
293 propagate_nans_(propagate_nans) {}
294
propagate_nans()295 bool propagate_nans() const {
296 return propagate_nans_;
297 }
298
299 static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
make(const ExprHandle & lhs,const ExprHandle & rhs,bool propagate_nans)300 static ExprHandle make(
301 const ExprHandle& lhs,
302 const ExprHandle& rhs,
303 bool propagate_nans) {
304 return ExprHandle(alloc<Min>(lhs.node(), rhs.node(), propagate_nans));
305 }
306 };
307
// Encode typed immediate values e.g. IntImm, FloatImm.
// For every scalar type (plus Bool, Half, BFloat16) this declares a leaf
// expression node (`<Name>Imm`) holding a compile-time-known constant of
// that type. These are the only nodes for which isConstant() is
// unconditionally true.
#define IMM_DECLARE(Type, Name)                               \
  class TORCH_API Name##Imm : public ExprNode<Name##Imm> {    \
   public:                                                    \
    Name##Imm(Type value)                                     \
        : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
    bool isConstant() const override {                        \
      return true;                                            \
    }                                                         \
    Type value() const {                                      \
      return value_;                                          \
    }                                                         \
    static ExprHandle make(Type value) {                      \
      return ExprHandle(alloc<Name##Imm>(value));             \
    }                                                         \
                                                              \
   private:                                                   \
    Type value_;                                              \
  };
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE
329
// Get immediate by ScalarType.
// Allocates the immediate node matching `immType`, holding `initialVal`
// converted to that scalar type. Throws unsupported_dtype for scalar types
// that have no Imm node.
template <typename T>
ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
  switch (immType) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return alloc<Name##Imm>(Type(initialVal));
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    default:
      throw unsupported_dtype();
  }
  // Unreachable; all paths above either return or throw.
  return nullptr;
}
344
345 template <typename T>
getImmediateByType(Dtype dtype,T initialVal)346 ExprPtr getImmediateByType(Dtype dtype, T initialVal) {
347 return getImmediateByType<T>(dtype.scalar_type(), initialVal);
348 }
349
350 template <typename T>
immLike(const ExprPtr & e,T v)351 ExprPtr immLike(const ExprPtr& e, T v) {
352 return getImmediateByType<T>(e->dtype(), v);
353 }
354
355 template <typename T>
immLike(const ExprHandle & e,T v)356 ExprPtr immLike(const ExprHandle& e, T v) {
357 return immLike(e.node(), v);
358 }
359
// If `e` is an integer immediate (of any integral width), returns its value
// widened to int64_t; otherwise returns nullopt.
inline std::optional<int64_t> intValue(const ExprPtr& e) {
#define TYPE_CASE(Type, Name)    \
  if (auto v = to<Name##Imm>(e)) { \
    return v->value();             \
  }
  AT_FORALL_INT_TYPES(TYPE_CASE);
#undef TYPE_CASE
  return std::nullopt;
}
369
intValue(const ExprHandle & e)370 inline std::optional<int64_t> intValue(const ExprHandle& e) {
371 return intValue(e.node());
372 }
373
// Extracts the value of an immediate node, converted to T. Works for every
// scalar type that has an Imm node; throws unsupported_dtype if `e` is not
// an immediate.
template <typename T>
T immediateAs(const ExprPtr& e) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value();                     \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  // Unreachable; kept to satisfy the return type.
  return 0;
}
385
386 template <typename T>
immediateAs(const ExprHandle & e)387 T immediateAs(const ExprHandle& e) {
388 return immediateAs<T>(e.node());
389 }
390
// Returns true iff `e` is an immediate whose value compares equal to `val`
// (comparison is done in the immediate's own type, after the usual
// arithmetic conversions). Throws unsupported_dtype if `e` is not an
// immediate.
template <typename T>
bool immediateEquals(const ExprPtr& e, T val) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() == val;              \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  // Unreachable; kept to satisfy the return type.
  return false;
}
402
403 TORCH_API bool immediateIsNegative(const ExprPtr& e);
404
405 TORCH_API bool immediateIsPositive(const ExprPtr& e);
406
407 TORCH_API bool immediateIsZero(const ExprPtr& e);
408
409 // Represents a ramp vector node:
410 // [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
411 class TORCH_API Ramp : public ExprNode<Ramp> {
412 public:
base()413 ExprPtr base() const {
414 return base_;
415 }
stride()416 ExprPtr stride() const {
417 return stride_;
418 }
419
set_base(ExprPtr base)420 void set_base(ExprPtr base) {
421 base_ = std::move(base);
422 }
423
set_stride(ExprPtr stride)424 void set_stride(ExprPtr stride) {
425 stride_ = std::move(stride);
426 }
427
make(const ExprHandle & base,const ExprHandle & stride,int64_t lanes)428 static ExprHandle make(
429 const ExprHandle& base,
430 const ExprHandle& stride,
431 int64_t lanes) {
432 if (stride.dtype() != base.dtype()) {
433 throw malformed_input("Bad stride in Ramp");
434 }
435 return ExprHandle(alloc<Ramp>(base.node(), stride.node(), lanes));
436 }
lanes()437 int64_t lanes() const {
438 return lanes_;
439 }
440
Ramp(ExprPtr base,ExprPtr stride,int64_t lanes)441 Ramp(ExprPtr base, ExprPtr stride, int64_t lanes)
442 : ExprNodeBase(Dtype(base->dtype(), lanes)),
443 base_(std::move(base)),
444 stride_(std::move(stride)),
445 lanes_(lanes) {}
446
447 private:
448 ExprPtr base_;
449 ExprPtr stride_;
450 int64_t lanes_;
451 };
452
// Reads a value from a buffer at the given indices. Factories and
// constructors are defined out-of-line.
class TORCH_API Load : public ExprNode<Load> {
 public:
  // The variable underlying the buffer being read.
  VarPtr base_handle() const {
    return buf_->base_handle();
  }
  // Per-dimension index expressions (a copy).
  std::vector<ExprPtr> indices() const {
    return indices_;
  }
  // The single linearized index; only valid after index flattening has
  // reduced the load to one dimension.
  ExprPtr flat_index() const {
    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
    return indices_[0];
  }
  BufPtr buf() const {
    return buf_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_indices(std::vector<ExprPtr> indices) {
    indices_ = std::move(indices);
  }

  // Factory with an explicit result dtype.
  static ExprHandle make(
      Dtype dtype,
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);
  // Factory deriving the dtype from the buffer.
  static ExprHandle make(
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);

  Load(Dtype dtype, BufPtr base_handle, std::vector<ExprPtr> indices);
  Load(const BufPtr& base_handle, const std::vector<ExprPtr>& indices);

 private:
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
};
492
493 class TORCH_API Broadcast : public ExprNode<Broadcast> {
494 public:
value()495 ExprPtr value() const {
496 return value_;
497 }
498
set_value(ExprPtr value)499 void set_value(ExprPtr value) {
500 value_ = std::move(value);
501 }
502
lanes()503 int64_t lanes() const {
504 return lanes_;
505 }
make(const ExprHandle & value,int64_t lanes)506 static ExprHandle make(const ExprHandle& value, int64_t lanes) {
507 return ExprHandle(alloc<Broadcast>(value.node(), lanes));
508 }
Broadcast(ExprPtr value,int64_t lanes)509 Broadcast(ExprPtr value, int64_t lanes)
510 : ExprNodeBase(Dtype(value->dtype(), lanes)),
511 value_(std::move(value)),
512 lanes_(lanes) {}
513
514 private:
515 ExprPtr value_;
516 int64_t lanes_;
517 };
518
519 class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
520 public:
condition()521 ExprPtr condition() const {
522 return condition_;
523 }
524
525 // Lazily evaluated only if condition is true
true_value()526 ExprPtr true_value() const {
527 return true_;
528 }
529
530 // Lazily evaluated only if condition is false
false_value()531 ExprPtr false_value() const {
532 return false_;
533 }
534
set_condition(ExprPtr condition)535 void set_condition(ExprPtr condition) {
536 condition_ = std::move(condition);
537 }
538
set_true_value(ExprPtr true_value)539 void set_true_value(ExprPtr true_value) {
540 true_ = std::move(true_value);
541 }
542
set_false_value(ExprPtr false_value)543 void set_false_value(ExprPtr false_value) {
544 false_ = std::move(false_value);
545 }
546
make(const ExprHandle & c,const ExprHandle & t,const ExprHandle & f)547 static ExprHandle make(
548 const ExprHandle& c,
549 const ExprHandle& t,
550 const ExprHandle& f) {
551 if (!c.dtype().is_integral()) {
552 throw unsupported_dtype();
553 }
554 if (c.dtype().lanes() != 1) {
555 throw unsupported_dtype();
556 }
557 if (t.dtype() != f.dtype()) {
558 throw malformed_input("Bad dtype in IfThenElse");
559 }
560 return ExprHandle(alloc<IfThenElse>(c.node(), t.node(), f.node()));
561 }
562
IfThenElse(ExprPtr c,ExprPtr t,ExprPtr f)563 IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f)
564 : ExprNodeBase(t->dtype()),
565 condition_(std::move(c)),
566 true_(std::move(t)),
567 false_(std::move(f)) {}
568
569 private:
570 ExprPtr condition_;
571 ExprPtr true_;
572 ExprPtr false_;
573 };
574
// Fused compare-and-select: evaluates `lhs <cmp_op> rhs` and yields
// `ret_val1` when the comparison holds, `ret_val2` otherwise. The two-value
// forms default the results to IntImm 1/0, turning the node into a plain
// boolean-valued comparison of dtype kInt.
class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
 public:
  CompareSelectOperation compare_select_op() const {
    return compare_op_;
  }
  ExprPtr lhs() const {
    return this->lhs_;
  }
  ExprPtr rhs() const {
    return this->rhs_;
  }
  // Value produced when the comparison is true.
  ExprPtr ret_val1() const {
    return this->ret_val1_;
  }
  // Value produced when the comparison is false.
  ExprPtr ret_val2() const {
    return this->ret_val2_;
  }

  void set_lhs(ExprPtr lhs) {
    lhs_ = std::move(lhs);
  }

  void set_rhs(ExprPtr rhs) {
    rhs_ = std::move(rhs);
  }

  void set_ret_val1(ExprPtr ret_val1) {
    ret_val1_ = std::move(ret_val1);
  }

  void set_ret_val2(ExprPtr ret_val2) {
    ret_val2_ = std::move(ret_val2);
  }

  // Branch-prediction hint for codegen.
  CompareSelectBias bias() const {
    return bias_;
  }

  // Boolean-style factory: result is IntImm 1 or 0.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        IntImm::make(1).node(),
        IntImm::make(0).node(),
        cmp_op,
        bias));
  }

  // Select-style factory: result is ret_val1 or ret_val2. The comparison
  // operands must share a dtype, as must the two result values.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      const ExprHandle& ret_val1,
      const ExprHandle& ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        ret_val1.node(),
        ret_val2.node(),
        cmp_op,
        bias));
  }

  // The node's dtype follows the selected values, not the compared operands.
  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      ExprPtr ret_val1,
      ExprPtr ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(ret_val1->dtype()),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(std::move(ret_val1)),
        ret_val2_(std::move(ret_val2)),
        compare_op_(cmp_op),
        bias_(bias) {}

  // Boolean-style constructor: results default to IntImm 1/0, dtype kInt.
  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(kInt),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(alloc<IntImm>(1)),
        ret_val2_(alloc<IntImm>(0)),
        compare_op_(cmp_op),
        bias_(bias) {}

 private:
  ExprPtr lhs_;
  ExprPtr rhs_;
  ExprPtr ret_val1_;
  ExprPtr ret_val2_;
  CompareSelectOperation compare_op_;
  CompareSelectBias bias_;
};
685
// Identifies the math/builtin function an Intrinsics node represents.
// kMaxIntrinsicsOp is a sentinel (one past the last valid op), not an op.
enum IntrinsicsOp {
  kSin,
  kCos,
  kTan,
  kAsin,
  kAcos,
  kAtan,
  kAtan2,
  kSinh,
  kCosh,
  kTanh,
  kSigmoid,
  kExp,
  kExpm1,
  kAbs,
  kLog,
  kLog2,
  kLog10,
  kLog1p,
  kErf,
  kErfc,
  kSqrt,
  kRsqrt,
  kPow,
  kCeil,
  kFloor,
  kRound,
  kTrunc,
  kFmod,
  kRemainder,
  kLgamma,
  kFrac,
  kIsNan,
  kRand, // We need more discussions on this. Should we consider stateful?
  kMaxIntrinsicsOp,
};
722
// Represents a call to a math/builtin function (sin, pow, rand, ...). Every
// constructor validates the argument count against OpArgCount and derives
// the result dtype via IntrinsicsDtype.
class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
 public:
  // Factory for unary intrinsics.
  static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node()));
  }

  // Factory for binary intrinsics (e.g. atan2, pow, fmod).
  static ExprHandle make(
      IntrinsicsOp op_type,
      const ExprHandle& v1,
      const ExprHandle& v2) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node(), v2.node()));
  }

  // Factory for an arbitrary parameter list.
  static ExprHandle make(
      IntrinsicsOp op_type,
      const std::vector<ExprHandle>& params) {
    std::vector<ExprPtr> params_nodes(params.size());
    for (size_t i = 0; i < params.size(); i++) {
      params_nodes[i] = params[i].node();
    }
    return ExprHandle(alloc<Intrinsics>(op_type, params_nodes));
  }

  // Factory for nullary intrinsics, which need an explicit result dtype.
  static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
    return ExprHandle(alloc<Intrinsics>(op_type, dtype));
  }

  IntrinsicsOp op_type() const {
    return op_type_;
  }

  // Maps the op to the function name used when printing / generating code.
  // Throws for values without a name (e.g. the kMaxIntrinsicsOp sentinel).
  std::string func_name() const {
    switch (op_type()) {
      case kSin:
        return "sin";
      case kCos:
        return "cos";
      case kTan:
        return "tan";
      case kAsin:
        return "asin";
      case kAcos:
        return "acos";
      case kAtan:
        return "atan";
      case kAtan2:
        return "atan2";
      case kSinh:
        return "sinh";
      case kCosh:
        return "cosh";
      case kTanh:
        return "tanh";
      case kSigmoid:
        return "sigmoid";
      case kExp:
        return "exp";
      case kAbs:
        return "abs";
      case kLog:
        return "log";
      case kLog2:
        return "log2";
      case kLog10:
        return "log10";
      case kLog1p:
        return "log1p";
      case kErf:
        return "erf";
      case kSqrt:
        return "sqrt";
      case kRsqrt:
        return "rsqrt";
      case kPow:
        return "pow";
      case kCeil:
        return "ceil";
      case kFloor:
        return "floor";
      case kRound:
        return "round";
      case kTrunc:
        return "trunc";
      case kRand:
        return "rand";
      case kFmod:
        return "fmod";
      case kRemainder:
        return "remainder";
      case kLgamma:
        return "lgamma";
      case kExpm1:
        return "expm1";
      case kErfc:
        return "erfc";
      case kFrac:
        return "frac";
      case kIsNan:
        return "isnan";
      default:
        throw std::runtime_error(
            "invalid op_type: " + std::to_string(op_type()));
    }
  }

  // Nullary constructor (op must take zero arguments).
  Intrinsics(IntrinsicsOp op_type, Dtype dtype)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_({}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 0) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // Unary constructor; result dtype is derived from the operand's.
  Intrinsics(IntrinsicsOp op_type, ExprPtr v1)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
        params_({std::move(v1)}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 1) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // Binary constructor; result dtype is derived from both operands'.
  Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
        params_({std::move(v1), std::move(v2)}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 2) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // N-ary constructor; result dtype is derived from the parameter list.
  Intrinsics(IntrinsicsOp op_type, const std::vector<ExprPtr>& params)
      : ExprNodeBase(IntrinsicsDtype(op_type, params)),
        params_(params),
        op_type_(op_type) {
    if (OpArgCount(op_type) != nparams()) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // N-ary constructor with an explicitly supplied result dtype.
  Intrinsics(
      IntrinsicsOp op_type,
      Dtype dtype,
      const std::vector<ExprPtr>& params)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_(params),
        op_type_(op_type) {
    if (OpArgCount(op_type) != nparams()) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // kRand is the only impure op: each evaluation may yield a new value.
  bool isPure() const {
    return op_type_ != kRand;
  }

  size_t nparams() const {
    return params_.size();
  }

  // NOTE(review): no bounds check - `index` is presumed valid by callers.
  ExprPtr param(size_t index) const {
    return params_[index];
  }
  const std::vector<ExprPtr>& params() const {
    return params_;
  }

  void set_params(std::vector<ExprPtr> params) {
    params_ = std::move(params);
  }

  // Number of arguments each op expects; defined out-of-line.
  static size_t OpArgCount(IntrinsicsOp op_type);

 private:
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
  static Dtype IntrinsicsDtype(
      IntrinsicsOp op_type,
      const std::vector<ExprPtr>& params);

  std::vector<ExprPtr> params_;
  IntrinsicsOp op_type_;
};
907
908 TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
909 const std::vector<ExprHandle>&);
910 TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
911 const std::vector<ExprPtr>&);
912 TORCH_API std::vector<VarPtr> VarHandleVectorToVarVector(
913 const std::vector<VarHandle>&);
914 TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
915 const std::vector<VarPtr>&);
916 TORCH_API ExprPtr flatten_index(
917 const std::vector<ExprPtr>& dims,
918 const std::vector<ExprPtr>& indices,
919 const std::vector<ExprPtr>& strides);
920
921 } // namespace tensorexpr
922 } // namespace jit
923 } // namespace torch
924