xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <string>
4 #include <utility>
5 #include <vector>
6 
7 #include <torch/csrc/jit/tensorexpr/exceptions.h>
8 #include <torch/csrc/jit/tensorexpr/expr.h>
9 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
10 #include <torch/csrc/jit/tensorexpr/stmt.h>
11 
12 #include <ATen/core/ivalue.h>
13 
14 namespace torch {
15 namespace jit {
16 namespace tensorexpr {
17 
// Comparison operators that a CompareSelect node can apply to its lhs/rhs
// (see CompareSelect::compare_select_op()). Values are spelled out
// explicitly; they are identical to the implicit 0..5 numbering.
enum CompareSelectOperation {
  kEQ = 0,
  kGT = 1,
  kGE = 2,
  kLT = 3,
  kLE = 4,
  kNE = 5,
};
26 
// Hint attached to a CompareSelect node (see CompareSelect::bias()).
// kUnbiased is the default used by CompareSelect's factories/constructors;
// presumably the other two tell code generators which outcome to expect.
enum CompareSelectBias {
  kUnbiased = 0,
  kLikely = 1,
  kUnlikely = 2,
};
32 
getPrecedence(IRNodeType ty)33 inline int getPrecedence(IRNodeType ty) {
34   // Match C++ operator precedence rules, since some pretty-print expressions to
35   // C++. SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
36   switch (ty) {
37     case kPrimitive:
38       return 0;
39     case kCast:
40     case kBitCast:
41       return 2;
42     case kAdd:
43     case kSub:
44       return 6;
45     case kMul:
46     case kDiv:
47     case kMod:
48       return 5;
49     case kMax:
50     case kMin:
51       return 99;
52     case kAnd:
53       return 11;
54     case kOr:
55       return 13;
56     case kLshift:
57     case kRshift:
58       return 7;
59     case kXor:
60       return 12;
61     case kCompareSelect:
62       return 16;
63     default:
64       return 99;
65   }
66 }
67 
68 class TORCH_API Cast : public ExprNode<Cast> {
69  public:
src_value()70   ExprPtr src_value() const {
71     return src_value_;
72   }
73 
set_src_value(ExprPtr src_value)74   void set_src_value(ExprPtr src_value) {
75     src_value_ = std::move(src_value);
76   }
77 
make(Dtype dtype,const ExprHandle & src_value)78   static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
79     return ExprHandle(alloc<Cast>(dtype, src_value.node()));
80   }
Cast(Dtype dtype,ExprPtr src_value)81   Cast(Dtype dtype, ExprPtr src_value)
82       : ExprNodeBase(dtype, kCast), src_value_(std::move(src_value)) {}
83 
isConstant()84   bool isConstant() const override {
85     return src_value_->isConstant();
86   }
87 
88  private:
89   ExprPtr src_value_;
90 };
91 
92 template <typename T>
cast(const ExprHandle & src_value)93 ExprHandle cast(const ExprHandle& src_value) {
94   return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
95 }
96 
97 // This is a bitwise cast, akin to bitcast in LLVM
98 class TORCH_API BitCast : public ExprNode<BitCast> {
99  public:
src_value()100   ExprPtr src_value() const {
101     return src_value_;
102   }
103 
set_src_value(ExprPtr src_value)104   void set_src_value(ExprPtr src_value) {
105     src_value_ = std::move(src_value);
106   }
107 
make(Dtype dtype,const ExprHandle & src_value)108   static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
109     return ExprHandle(alloc<BitCast>(dtype, src_value.node()));
110   }
BitCast(Dtype dtype,ExprPtr src_value)111   BitCast(Dtype dtype, ExprPtr src_value)
112       : ExprNodeBase(dtype, kBitCast), src_value_(std::move(src_value)) {
113     TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
114   }
115 
isConstant()116   bool isConstant() const override {
117     return src_value_->isConstant();
118   }
119 
120  private:
121   ExprPtr src_value_;
122 };
123 
124 template <typename T>
bitcast(const ExprHandle & src_value)125 ExprHandle bitcast(const ExprHandle& src_value) {
126   return BitCast::make(
127       Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
128 }
129 
130 // Represent the expression node for binary operators.
131 // A CRTP pattern to share common code among the operators.
132 template <typename Op>
133 class BinaryOpNode : public ExprNode<Op> {
134  public:
lhs()135   ExprPtr lhs() const {
136     return this->lhs_;
137   }
rhs()138   ExprPtr rhs() const {
139     return this->rhs_;
140   }
141 
set_lhs(ExprPtr lhs)142   void set_lhs(ExprPtr lhs) {
143     lhs_ = std::move(lhs);
144   }
145 
set_rhs(ExprPtr rhs)146   void set_rhs(ExprPtr rhs) {
147     rhs_ = std::move(rhs);
148   }
149 
make(const ExprHandle & lhs,const ExprHandle & rhs)150   static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
151     return ExprHandle(alloc<Op>(lhs.node(), rhs.node()));
152   }
153 
154   BinaryOpNode(
155       ExprPtr lhs_v,
156       ExprPtr rhs_v,
157       IRNodeType expr_type,
158       ScalarType ret_type = ScalarType::Undefined)
159       : ExprNode<Op>(
160             BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type),
161             expr_type),
162         lhs_(CastIfNeeded(std::move(lhs_v), ExprNode<Op>::dtype())),
163         rhs_(CastIfNeeded(std::move(rhs_v), ExprNode<Op>::dtype())) {}
164 
165  private:
CastIfNeeded(ExprPtr expr,Dtype dst_dtype)166   static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) {
167     if (expr->dtype() == dst_dtype) {
168       return expr;
169     }
170     return Cast::make(dst_dtype, ExprHandle(std::move(expr))).node();
171   }
172 
173   ExprPtr lhs_;
174   ExprPtr rhs_;
175 };
176 
177 namespace detail {
178 template <typename T>
179 void bin_op_deducer(BinaryOpNode<T>);
180 bool bin_op_deducer(...);
181 } // namespace detail
182 
183 class TORCH_API Add : public BinaryOpNode<Add> {
184  public:
Add(ExprPtr lhs,ExprPtr rhs)185   Add(ExprPtr lhs, ExprPtr rhs)
186       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAdd) {}
187 };
188 
189 class TORCH_API Sub : public BinaryOpNode<Sub> {
190  public:
Sub(ExprPtr lhs,ExprPtr rhs)191   Sub(ExprPtr lhs, ExprPtr rhs)
192       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kSub) {}
193 };
194 
195 class TORCH_API Mul : public BinaryOpNode<Mul> {
196  public:
Mul(ExprPtr lhs,ExprPtr rhs)197   Mul(ExprPtr lhs, ExprPtr rhs)
198       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMul) {}
199 };
200 
201 class TORCH_API Div : public BinaryOpNode<Div> {
202  public:
Div(ExprPtr lhs,ExprPtr rhs)203   Div(ExprPtr lhs, ExprPtr rhs)
204       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kDiv) {}
205 };
206 
207 class TORCH_API Mod : public BinaryOpNode<Mod> {
208  public:
Mod(ExprPtr lhs,ExprPtr rhs)209   Mod(ExprPtr lhs, ExprPtr rhs)
210       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMod) {}
211 };
212 
213 template <typename Op>
214 class BitwiseOpNode : public BinaryOpNode<Op> {
215  public:
BitwiseOpNode(ExprPtr lhs,ExprPtr rhs,IRNodeType type)216   BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type)
217       : BinaryOpNode<Op>(std::move(lhs), std::move(rhs), type) {}
218 
make(const ExprHandle & lhs,const ExprHandle & rhs)219   static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
220     if (!lhs.dtype().is_integral()) {
221       throw unsupported_dtype();
222     }
223     if (lhs.dtype() != rhs.dtype()) {
224       throw malformed_input("lhs/rhs dtype mismatch");
225     }
226     return BinaryOpNode<Op>::make(lhs, rhs);
227   }
228 };
229 
230 class TORCH_API And : public BitwiseOpNode<And> {
231  public:
And(ExprPtr lhs,ExprPtr rhs)232   And(ExprPtr lhs, ExprPtr rhs)
233       : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAnd) {}
234 };
235 
236 class TORCH_API Or : public BitwiseOpNode<Or> {
237  public:
Or(ExprPtr lhs,ExprPtr rhs)238   Or(ExprPtr lhs, ExprPtr rhs)
239       : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOr) {}
240 };
241 
242 class TORCH_API Xor : public BitwiseOpNode<Xor> {
243  public:
Xor(ExprPtr lhs,ExprPtr rhs)244   Xor(ExprPtr lhs, ExprPtr rhs)
245       : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kXor) {}
246 };
247 
248 class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
249  public:
Lshift(ExprPtr lhs,ExprPtr rhs)250   Lshift(ExprPtr lhs, ExprPtr rhs)
251       : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kLshift) {}
252 };
253 
254 class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
255  public:
Rshift(ExprPtr lhs,ExprPtr rhs)256   Rshift(ExprPtr lhs, ExprPtr rhs)
257       : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kRshift) {}
258 };
259 
260 // TODO: add TORCH_API
261 // Currently adding it results in a compilation error on Windows
262 class Max : public BinaryOpNode<Max> {
263  private:
264   bool propagate_nans_;
265 
266  public:
Max(ExprPtr lhs,ExprPtr rhs,bool propagate_nans)267   Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
268       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMax),
269         propagate_nans_(propagate_nans) {}
270 
propagate_nans()271   bool propagate_nans() const {
272     return propagate_nans_;
273   }
274 
275   static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
make(const ExprHandle & lhs,const ExprHandle & rhs,bool propagate_nans)276   static ExprHandle make(
277       const ExprHandle& lhs,
278       const ExprHandle& rhs,
279       bool propagate_nans) {
280     return ExprHandle(alloc<Max>(lhs.node(), rhs.node(), propagate_nans));
281   }
282 };
283 
284 // TODO: add TORCH_API
285 // Currently adding it results in a compilation error on Windows
286 class Min : public BinaryOpNode<Min> {
287  private:
288   bool propagate_nans_;
289 
290  public:
Min(ExprPtr lhs,ExprPtr rhs,bool propagate_nans)291   Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
292       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMin),
293         propagate_nans_(propagate_nans) {}
294 
propagate_nans()295   bool propagate_nans() const {
296     return propagate_nans_;
297   }
298 
299   static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
make(const ExprHandle & lhs,const ExprHandle & rhs,bool propagate_nans)300   static ExprHandle make(
301       const ExprHandle& lhs,
302       const ExprHandle& rhs,
303       bool propagate_nans) {
304     return ExprHandle(alloc<Min>(lhs.node(), rhs.node(), propagate_nans));
305   }
306 };
307 
308 // Encode typed immediate values e.g. IntImm, FloatImm.
309 #define IMM_DECLARE(Type, Name)                               \
310   class TORCH_API Name##Imm : public ExprNode<Name##Imm> {    \
311    public:                                                    \
312     Name##Imm(Type value)                                     \
313         : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
314     bool isConstant() const override {                        \
315       return true;                                            \
316     }                                                         \
317     Type value() const {                                      \
318       return value_;                                          \
319     }                                                         \
320     static ExprHandle make(Type value) {                      \
321       return ExprHandle(alloc<Name##Imm>(value));             \
322     }                                                         \
323                                                               \
324    private:                                                   \
325     Type value_;                                              \
326   };
327 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
328 #undef IMM_DECLARE
329 
330 // Get immediate by ScalarType.
331 template <typename T>
getImmediateByType(ScalarType immType,T initialVal)332 ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
333   switch (immType) {
334 #define TYPE_CASE(Type, Name) \
335   case ScalarType::Name:      \
336     return alloc<Name##Imm>(Type(initialVal));
337     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
338 #undef TYPE_CASE
339     default:
340       throw unsupported_dtype();
341   }
342   return nullptr;
343 }
344 
345 template <typename T>
getImmediateByType(Dtype dtype,T initialVal)346 ExprPtr getImmediateByType(Dtype dtype, T initialVal) {
347   return getImmediateByType<T>(dtype.scalar_type(), initialVal);
348 }
349 
350 template <typename T>
immLike(const ExprPtr & e,T v)351 ExprPtr immLike(const ExprPtr& e, T v) {
352   return getImmediateByType<T>(e->dtype(), v);
353 }
354 
355 template <typename T>
immLike(const ExprHandle & e,T v)356 ExprPtr immLike(const ExprHandle& e, T v) {
357   return immLike(e.node(), v);
358 }
359 
intValue(const ExprPtr & e)360 inline std::optional<int64_t> intValue(const ExprPtr& e) {
361 #define TYPE_CASE(Type, Name)      \
362   if (auto v = to<Name##Imm>(e)) { \
363     return v->value();             \
364   }
365   AT_FORALL_INT_TYPES(TYPE_CASE);
366 #undef TYPE_CASE
367   return std::nullopt;
368 }
369 
intValue(const ExprHandle & e)370 inline std::optional<int64_t> intValue(const ExprHandle& e) {
371   return intValue(e.node());
372 }
373 
374 template <typename T>
immediateAs(const ExprPtr & e)375 T immediateAs(const ExprPtr& e) {
376 #define TYPE_CASE(Type, Name)                \
377   if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
378     return imm->value();                     \
379   }
380   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
381 #undef TYPE_CASE
382   throw unsupported_dtype();
383   return 0;
384 }
385 
386 template <typename T>
immediateAs(const ExprHandle & e)387 T immediateAs(const ExprHandle& e) {
388   return immediateAs<T>(e.node());
389 }
390 
391 template <typename T>
immediateEquals(const ExprPtr & e,T val)392 bool immediateEquals(const ExprPtr& e, T val) {
393 #define TYPE_CASE(Type, Name)                \
394   if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
395     return imm->value() == val;              \
396   }
397   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
398 #undef TYPE_CASE
399   throw unsupported_dtype();
400   return false;
401 }
402 
403 TORCH_API bool immediateIsNegative(const ExprPtr& e);
404 
405 TORCH_API bool immediateIsPositive(const ExprPtr& e);
406 
407 TORCH_API bool immediateIsZero(const ExprPtr& e);
408 
409 // Represents a ramp vector node:
410 //     [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
411 class TORCH_API Ramp : public ExprNode<Ramp> {
412  public:
base()413   ExprPtr base() const {
414     return base_;
415   }
stride()416   ExprPtr stride() const {
417     return stride_;
418   }
419 
set_base(ExprPtr base)420   void set_base(ExprPtr base) {
421     base_ = std::move(base);
422   }
423 
set_stride(ExprPtr stride)424   void set_stride(ExprPtr stride) {
425     stride_ = std::move(stride);
426   }
427 
make(const ExprHandle & base,const ExprHandle & stride,int64_t lanes)428   static ExprHandle make(
429       const ExprHandle& base,
430       const ExprHandle& stride,
431       int64_t lanes) {
432     if (stride.dtype() != base.dtype()) {
433       throw malformed_input("Bad stride in Ramp");
434     }
435     return ExprHandle(alloc<Ramp>(base.node(), stride.node(), lanes));
436   }
lanes()437   int64_t lanes() const {
438     return lanes_;
439   }
440 
Ramp(ExprPtr base,ExprPtr stride,int64_t lanes)441   Ramp(ExprPtr base, ExprPtr stride, int64_t lanes)
442       : ExprNodeBase(Dtype(base->dtype(), lanes)),
443         base_(std::move(base)),
444         stride_(std::move(stride)),
445         lanes_(lanes) {}
446 
447  private:
448   ExprPtr base_;
449   ExprPtr stride_;
450   int64_t lanes_;
451 };
452 
453 class TORCH_API Load : public ExprNode<Load> {
454  public:
base_handle()455   VarPtr base_handle() const {
456     return buf_->base_handle();
457   }
indices()458   std::vector<ExprPtr> indices() const {
459     return indices_;
460   }
flat_index()461   ExprPtr flat_index() const {
462     TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
463     return indices_[0];
464   }
buf()465   BufPtr buf() const {
466     return buf_;
467   }
468 
set_buf(BufPtr buf)469   void set_buf(BufPtr buf) {
470     buf_ = std::move(buf);
471   }
472 
set_indices(std::vector<ExprPtr> indices)473   void set_indices(std::vector<ExprPtr> indices) {
474     indices_ = std::move(indices);
475   }
476 
477   static ExprHandle make(
478       Dtype dtype,
479       const BufHandle& buf,
480       const std::vector<ExprHandle>& indices);
481   static ExprHandle make(
482       const BufHandle& buf,
483       const std::vector<ExprHandle>& indices);
484 
485   Load(Dtype dtype, BufPtr base_handle, std::vector<ExprPtr> indices);
486   Load(const BufPtr& base_handle, const std::vector<ExprPtr>& indices);
487 
488  private:
489   BufPtr buf_;
490   std::vector<ExprPtr> indices_;
491 };
492 
493 class TORCH_API Broadcast : public ExprNode<Broadcast> {
494  public:
value()495   ExprPtr value() const {
496     return value_;
497   }
498 
set_value(ExprPtr value)499   void set_value(ExprPtr value) {
500     value_ = std::move(value);
501   }
502 
lanes()503   int64_t lanes() const {
504     return lanes_;
505   }
make(const ExprHandle & value,int64_t lanes)506   static ExprHandle make(const ExprHandle& value, int64_t lanes) {
507     return ExprHandle(alloc<Broadcast>(value.node(), lanes));
508   }
Broadcast(ExprPtr value,int64_t lanes)509   Broadcast(ExprPtr value, int64_t lanes)
510       : ExprNodeBase(Dtype(value->dtype(), lanes)),
511         value_(std::move(value)),
512         lanes_(lanes) {}
513 
514  private:
515   ExprPtr value_;
516   int64_t lanes_;
517 };
518 
519 class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
520  public:
condition()521   ExprPtr condition() const {
522     return condition_;
523   }
524 
525   // Lazily evaluated only if condition is true
true_value()526   ExprPtr true_value() const {
527     return true_;
528   }
529 
530   // Lazily evaluated only if condition is false
false_value()531   ExprPtr false_value() const {
532     return false_;
533   }
534 
set_condition(ExprPtr condition)535   void set_condition(ExprPtr condition) {
536     condition_ = std::move(condition);
537   }
538 
set_true_value(ExprPtr true_value)539   void set_true_value(ExprPtr true_value) {
540     true_ = std::move(true_value);
541   }
542 
set_false_value(ExprPtr false_value)543   void set_false_value(ExprPtr false_value) {
544     false_ = std::move(false_value);
545   }
546 
make(const ExprHandle & c,const ExprHandle & t,const ExprHandle & f)547   static ExprHandle make(
548       const ExprHandle& c,
549       const ExprHandle& t,
550       const ExprHandle& f) {
551     if (!c.dtype().is_integral()) {
552       throw unsupported_dtype();
553     }
554     if (c.dtype().lanes() != 1) {
555       throw unsupported_dtype();
556     }
557     if (t.dtype() != f.dtype()) {
558       throw malformed_input("Bad dtype in IfThenElse");
559     }
560     return ExprHandle(alloc<IfThenElse>(c.node(), t.node(), f.node()));
561   }
562 
IfThenElse(ExprPtr c,ExprPtr t,ExprPtr f)563   IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f)
564       : ExprNodeBase(t->dtype()),
565         condition_(std::move(c)),
566         true_(std::move(t)),
567         false_(std::move(f)) {}
568 
569  private:
570   ExprPtr condition_;
571   ExprPtr true_;
572   ExprPtr false_;
573 };
574 
575 class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
576  public:
compare_select_op()577   CompareSelectOperation compare_select_op() const {
578     return compare_op_;
579   }
lhs()580   ExprPtr lhs() const {
581     return this->lhs_;
582   }
rhs()583   ExprPtr rhs() const {
584     return this->rhs_;
585   }
ret_val1()586   ExprPtr ret_val1() const {
587     return this->ret_val1_;
588   }
ret_val2()589   ExprPtr ret_val2() const {
590     return this->ret_val2_;
591   }
592 
set_lhs(ExprPtr lhs)593   void set_lhs(ExprPtr lhs) {
594     lhs_ = std::move(lhs);
595   }
596 
set_rhs(ExprPtr rhs)597   void set_rhs(ExprPtr rhs) {
598     rhs_ = std::move(rhs);
599   }
600 
set_ret_val1(ExprPtr ret_val1)601   void set_ret_val1(ExprPtr ret_val1) {
602     ret_val1_ = std::move(ret_val1);
603   }
604 
set_ret_val2(ExprPtr ret_val2)605   void set_ret_val2(ExprPtr ret_val2) {
606     ret_val2_ = std::move(ret_val2);
607   }
608 
bias()609   CompareSelectBias bias() const {
610     return bias_;
611   }
612 
613   static ExprHandle make(
614       const ExprHandle& lhs,
615       const ExprHandle& rhs,
616       CompareSelectOperation cmp_op,
617       CompareSelectBias bias = kUnbiased) {
618     if (lhs.dtype() != rhs.dtype()) {
619       throw malformed_input("bad dtype in CompareSelect");
620     }
621     return ExprHandle(alloc<CompareSelect>(
622         lhs.node(),
623         rhs.node(),
624         IntImm::make(1).node(),
625         IntImm::make(0).node(),
626         cmp_op,
627         bias));
628   }
629 
630   static ExprHandle make(
631       const ExprHandle& lhs,
632       const ExprHandle& rhs,
633       const ExprHandle& ret_val1,
634       const ExprHandle& ret_val2,
635       CompareSelectOperation cmp_op,
636       CompareSelectBias bias = kUnbiased) {
637     if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
638       throw malformed_input("bad dtype in CompareSelect");
639     }
640     return ExprHandle(alloc<CompareSelect>(
641         lhs.node(),
642         rhs.node(),
643         ret_val1.node(),
644         ret_val2.node(),
645         cmp_op,
646         bias));
647   }
648 
649   CompareSelect(
650       ExprPtr lhs,
651       ExprPtr rhs,
652       ExprPtr ret_val1,
653       ExprPtr ret_val2,
654       CompareSelectOperation cmp_op,
655       CompareSelectBias bias = kUnbiased)
656       : ExprNodeBase(ret_val1->dtype()),
657         lhs_(std::move(lhs)),
658         rhs_(std::move(rhs)),
659         ret_val1_(std::move(ret_val1)),
660         ret_val2_(std::move(ret_val2)),
661         compare_op_(cmp_op),
662         bias_(bias) {}
663 
664   CompareSelect(
665       ExprPtr lhs,
666       ExprPtr rhs,
667       CompareSelectOperation cmp_op,
668       CompareSelectBias bias = kUnbiased)
ExprNodeBase(kInt)669       : ExprNodeBase(kInt),
670         lhs_(std::move(lhs)),
671         rhs_(std::move(rhs)),
672         ret_val1_(alloc<IntImm>(1)),
673         ret_val2_(alloc<IntImm>(0)),
674         compare_op_(cmp_op),
675         bias_(bias) {}
676 
677  private:
678   ExprPtr lhs_;
679   ExprPtr rhs_;
680   ExprPtr ret_val1_;
681   ExprPtr ret_val2_;
682   CompareSelectOperation compare_op_;
683   CompareSelectBias bias_;
684 };
685 
// Built-in math operations available to Intrinsics nodes, numbered from 0 in
// declaration order. Intrinsics::func_name() maps each to its printable name.
// kMaxIntrinsicsOp is the final enumerator and has no func_name() entry;
// presumably it serves as a count/sentinel rather than a real op.
enum IntrinsicsOp {
  kSin = 0,
  kCos,
  kTan,
  kAsin,
  kAcos,
  kAtan,
  kAtan2,
  kSinh,
  kCosh,
  kTanh,
  kSigmoid,

  kExp,
  kExpm1,
  kAbs,
  kLog,
  kLog2,
  kLog10,
  kLog1p,
  kErf,
  kErfc,
  kSqrt,
  kRsqrt,
  kPow,

  kCeil,
  kFloor,
  kRound,
  kTrunc,
  kFmod,
  kRemainder,
  kLgamma,
  kFrac,
  kIsNan,

  kRand, // We need more discussions on this. Should we consider stateful?

  kMaxIntrinsicsOp,
};
722 
723 class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
724  public:
make(IntrinsicsOp op_type,const ExprHandle & v1)725   static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
726     return ExprHandle(alloc<Intrinsics>(op_type, v1.node()));
727   }
728 
make(IntrinsicsOp op_type,const ExprHandle & v1,const ExprHandle & v2)729   static ExprHandle make(
730       IntrinsicsOp op_type,
731       const ExprHandle& v1,
732       const ExprHandle& v2) {
733     return ExprHandle(alloc<Intrinsics>(op_type, v1.node(), v2.node()));
734   }
735 
make(IntrinsicsOp op_type,const std::vector<ExprHandle> & params)736   static ExprHandle make(
737       IntrinsicsOp op_type,
738       const std::vector<ExprHandle>& params) {
739     std::vector<ExprPtr> params_nodes(params.size());
740     for (size_t i = 0; i < params.size(); i++) {
741       params_nodes[i] = params[i].node();
742     }
743     return ExprHandle(alloc<Intrinsics>(op_type, params_nodes));
744   }
745 
make(IntrinsicsOp op_type,Dtype dtype)746   static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
747     return ExprHandle(alloc<Intrinsics>(op_type, dtype));
748   }
749 
op_type()750   IntrinsicsOp op_type() const {
751     return op_type_;
752   }
753 
func_name()754   std::string func_name() const {
755     switch (op_type()) {
756       case kSin:
757         return "sin";
758       case kCos:
759         return "cos";
760       case kTan:
761         return "tan";
762       case kAsin:
763         return "asin";
764       case kAcos:
765         return "acos";
766       case kAtan:
767         return "atan";
768       case kAtan2:
769         return "atan2";
770       case kSinh:
771         return "sinh";
772       case kCosh:
773         return "cosh";
774       case kTanh:
775         return "tanh";
776       case kSigmoid:
777         return "sigmoid";
778       case kExp:
779         return "exp";
780       case kAbs:
781         return "abs";
782       case kLog:
783         return "log";
784       case kLog2:
785         return "log2";
786       case kLog10:
787         return "log10";
788       case kLog1p:
789         return "log1p";
790       case kErf:
791         return "erf";
792       case kSqrt:
793         return "sqrt";
794       case kRsqrt:
795         return "rsqrt";
796       case kPow:
797         return "pow";
798       case kCeil:
799         return "ceil";
800       case kFloor:
801         return "floor";
802       case kRound:
803         return "round";
804       case kTrunc:
805         return "trunc";
806       case kRand:
807         return "rand";
808       case kFmod:
809         return "fmod";
810       case kRemainder:
811         return "remainder";
812       case kLgamma:
813         return "lgamma";
814       case kExpm1:
815         return "expm1";
816       case kErfc:
817         return "erfc";
818       case kFrac:
819         return "frac";
820       case kIsNan:
821         return "isnan";
822       default:
823         throw std::runtime_error(
824             "invalid op_type: " + std::to_string(op_type()));
825     }
826   }
827 
Intrinsics(IntrinsicsOp op_type,Dtype dtype)828   Intrinsics(IntrinsicsOp op_type, Dtype dtype)
829       : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
830         params_({}),
831         op_type_(op_type) {
832     if (OpArgCount(op_type) != 0) {
833       throw malformed_input("bad arg count in Intrinsics");
834     }
835   }
836 
Intrinsics(IntrinsicsOp op_type,ExprPtr v1)837   Intrinsics(IntrinsicsOp op_type, ExprPtr v1)
838       : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
839         params_({std::move(v1)}),
840         op_type_(op_type) {
841     if (OpArgCount(op_type) != 1) {
842       throw malformed_input("bad arg count in Intrinsics");
843     }
844   }
845 
Intrinsics(IntrinsicsOp op_type,ExprPtr v1,ExprPtr v2)846   Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2)
847       : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
848         params_({std::move(v1), std::move(v2)}),
849         op_type_(op_type) {
850     if (OpArgCount(op_type) != 2) {
851       throw malformed_input("bad arg count in Intrinsics");
852     }
853   }
854 
Intrinsics(IntrinsicsOp op_type,const std::vector<ExprPtr> & params)855   Intrinsics(IntrinsicsOp op_type, const std::vector<ExprPtr>& params)
856       : ExprNodeBase(IntrinsicsDtype(op_type, params)),
857         params_(params),
858         op_type_(op_type) {
859     if (OpArgCount(op_type) != nparams()) {
860       throw malformed_input("bad arg count in Intrinsics");
861     }
862   }
863 
Intrinsics(IntrinsicsOp op_type,Dtype dtype,const std::vector<ExprPtr> & params)864   Intrinsics(
865       IntrinsicsOp op_type,
866       Dtype dtype,
867       const std::vector<ExprPtr>& params)
868       : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
869         params_(params),
870         op_type_(op_type) {
871     if (OpArgCount(op_type) != nparams()) {
872       throw malformed_input("bad arg count in Intrinsics");
873     }
874   }
875 
isPure()876   bool isPure() const {
877     return op_type_ != kRand;
878   }
879 
nparams()880   size_t nparams() const {
881     return params_.size();
882   }
883 
param(size_t index)884   ExprPtr param(size_t index) const {
885     return params_[index];
886   }
params()887   const std::vector<ExprPtr>& params() const {
888     return params_;
889   }
890 
set_params(std::vector<ExprPtr> params)891   void set_params(std::vector<ExprPtr> params) {
892     params_ = std::move(params);
893   }
894 
895   static size_t OpArgCount(IntrinsicsOp op_type);
896 
897  private:
898   static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
899   static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
900   static Dtype IntrinsicsDtype(
901       IntrinsicsOp op_type,
902       const std::vector<ExprPtr>& params);
903 
904   std::vector<ExprPtr> params_;
905   IntrinsicsOp op_type_;
906 };
907 
908 TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
909     const std::vector<ExprHandle>&);
910 TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
911     const std::vector<ExprPtr>&);
912 TORCH_API std::vector<VarPtr> VarHandleVectorToVarVector(
913     const std::vector<VarHandle>&);
914 TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
915     const std::vector<VarPtr>&);
916 TORCH_API ExprPtr flatten_index(
917     const std::vector<ExprPtr>& dims,
918     const std::vector<ExprPtr>& indices,
919     const std::vector<ExprPtr>& strides);
920 
921 } // namespace tensorexpr
922 } // namespace jit
923 } // namespace torch
924