// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/fwd_decls.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/ScalarType.h>
3 #include <memory>
4 
5 namespace torch {
6 namespace jit {
7 namespace tensorexpr {
8 
/// Shared-ownership smart pointer used for the IR node types declared in
/// this header; every `*Ptr` alias below is spelled in terms of it.
template <class Node>
using NodePtr = ::std::shared_ptr<Node>;
11 
12 template <typename To, typename From>
to(const NodePtr<From> & x)13 NodePtr<To> to(const NodePtr<From>& x) {
14   return std::dynamic_pointer_cast<To>(x);
15 }
16 
17 template <typename To, typename From>
static_to(NodePtr<From> x)18 NodePtr<To> static_to(NodePtr<From> x) {
19   return std::static_pointer_cast<To>(x);
20 }
21 
22 template <typename Node, typename... Args>
alloc(Args &&...args)23 NodePtr<Node> alloc(Args&&... args) {
24   return std::make_shared<Node>(std::forward<Args>(args)...);
25 }
26 
27 class Buf;
28 class Expr;
29 class Stmt;
30 class Var;
31 
32 using BufPtr = NodePtr<Buf>;
33 using ExprPtr = NodePtr<Expr>;
34 using StmtPtr = NodePtr<Stmt>;
35 using VarPtr = NodePtr<Var>;
36 
37 class ExprHandle;
38 class VarHandle;
39 class BufHandle;
40 
41 class Add;
42 class And;
43 class BitCast;
44 class Broadcast;
45 class Cast;
46 class CompareSelect;
47 class Div;
48 class IfThenElse;
49 class Intrinsics;
50 class Let;
51 class Load;
52 class Lshift;
53 class Max;
54 class MaxTerm;
55 class Min;
56 class MinTerm;
57 class Mod;
58 class Mul;
59 class Or;
60 class Polynomial;
61 class Ramp;
62 class ReduceOp;
63 class RoundOff;
64 class Rshift;
65 class Store;
66 class Sub;
67 class Term;
68 class Xor;
69 using AddPtr = NodePtr<Add>;
70 using AndPtr = NodePtr<And>;
71 using BitCastPtr = NodePtr<BitCast>;
72 using BroadcastPtr = NodePtr<Broadcast>;
73 using CastPtr = NodePtr<Cast>;
74 using CompareSelectPtr = NodePtr<CompareSelect>;
75 using DivPtr = NodePtr<Div>;
76 using IfThenElsePtr = NodePtr<IfThenElse>;
77 using IntrinsicsPtr = NodePtr<Intrinsics>;
78 using LetPtr = NodePtr<Let>;
79 using LoadPtr = NodePtr<Load>;
80 using LshiftPtr = NodePtr<Lshift>;
81 using MaxPtr = NodePtr<Max>;
82 using MaxTermPtr = NodePtr<MaxTerm>;
83 using MinPtr = NodePtr<Min>;
84 using MinTermPtr = NodePtr<MinTerm>;
85 using ModPtr = NodePtr<Mod>;
86 using MulPtr = NodePtr<Mul>;
87 using OrPtr = NodePtr<Or>;
88 using PolynomialPtr = NodePtr<Polynomial>;
89 using RampPtr = NodePtr<Ramp>;
90 using ReduceOpPtr = NodePtr<ReduceOp>;
91 using RoundOffPtr = NodePtr<RoundOff>;
92 using RshiftPtr = NodePtr<Rshift>;
93 using StorePtr = NodePtr<Store>;
94 using SubPtr = NodePtr<Sub>;
95 using TermPtr = NodePtr<Term>;
96 using XorPtr = NodePtr<Xor>;
97 
98 class Allocate;
99 class AtomicAdd;
100 class Block;
101 class Cond;
102 class ExternalCall;
103 class ExternalCallWithAlloc;
104 class For;
105 class Free;
106 class FreeExt;
107 class PlacementAllocate;
108 class SyncThreads;
109 using AllocatePtr = NodePtr<Allocate>;
110 using AtomicAddPtr = NodePtr<AtomicAdd>;
111 using BlockPtr = NodePtr<Block>;
112 using CondPtr = NodePtr<Cond>;
113 using ExternalCallPtr = NodePtr<ExternalCall>;
114 using ExternalCallWithAllocPtr = NodePtr<ExternalCallWithAlloc>;
115 using ForPtr = NodePtr<For>;
116 using FreePtr = NodePtr<Free>;
117 using FreeExtPtr = NodePtr<FreeExt>;
118 using PlacementAllocatePtr = NodePtr<PlacementAllocate>;
119 using SyncThreadsPtr = NodePtr<SyncThreads>;
120 
121 #define IMM_DECLARE(Type, Name) \
122   class Name##Imm;              \
123   using Name##ImmPtr = NodePtr<Name##Imm>;
124 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
125 #undef IMM_DECLARE
126 
127 } // namespace tensorexpr
128 } // namespace jit
129 } // namespace torch
130