xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_visitor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7 
8 #include <c10/util/irange.h>
9 
10 namespace torch::jit::tensorexpr {
11 
12 template <
13     typename Op,
14     std::enable_if_t<std::is_same_v<
15         decltype(detail::bin_op_deducer(std::declval<Op>())),
16         void>>* = nullptr>
visit_binary_op(const NodePtr<Op> & v,IRVisitor * visitor)17 static void visit_binary_op(const NodePtr<Op>& v, IRVisitor* visitor) {
18   v->lhs()->accept(visitor);
19   v->rhs()->accept(visitor);
20 }
21 
visit(const AddPtr & v)22 void IRVisitor::visit(const AddPtr& v) {
23   visit_binary_op(v, this);
24 }
25 
visit(const SubPtr & v)26 void IRVisitor::visit(const SubPtr& v) {
27   visit_binary_op(v, this);
28 }
29 
visit(const MulPtr & v)30 void IRVisitor::visit(const MulPtr& v) {
31   visit_binary_op(v, this);
32 }
33 
visit(const DivPtr & v)34 void IRVisitor::visit(const DivPtr& v) {
35   visit_binary_op(v, this);
36 }
37 
visit(const ModPtr & v)38 void IRVisitor::visit(const ModPtr& v) {
39   visit_binary_op(v, this);
40 }
41 
visit(const MaxPtr & v)42 void IRVisitor::visit(const MaxPtr& v) {
43   visit_binary_op(v, this);
44 }
45 
visit(const MinPtr & v)46 void IRVisitor::visit(const MinPtr& v) {
47   visit_binary_op(v, this);
48 }
49 
visit(const AndPtr & v)50 void IRVisitor::visit(const AndPtr& v) {
51   visit_binary_op(v, this);
52 }
53 
visit(const OrPtr & v)54 void IRVisitor::visit(const OrPtr& v) {
55   visit_binary_op(v, this);
56 }
57 
visit(const XorPtr & v)58 void IRVisitor::visit(const XorPtr& v) {
59   visit_binary_op(v, this);
60 }
61 
visit(const LshiftPtr & v)62 void IRVisitor::visit(const LshiftPtr& v) {
63   visit_binary_op(v, this);
64 }
65 
visit(const RshiftPtr & v)66 void IRVisitor::visit(const RshiftPtr& v) {
67   visit_binary_op(v, this);
68 }
69 
visit(const CompareSelectPtr & v)70 void IRVisitor::visit(const CompareSelectPtr& v) {
71   v->lhs()->accept(this);
72   v->rhs()->accept(this);
73   v->ret_val1()->accept(this);
74   v->ret_val2()->accept(this);
75 }
76 
77 #define IMM_VISIT(Type, Name) \
78   void IRVisitor::visit(const Name##ImmPtr& v) {}
79 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
80 #undef IMM_VISIT
81 
visit(const CastPtr & v)82 void IRVisitor::visit(const CastPtr& v) {
83   v->src_value()->accept(this);
84 }
visit(const BitCastPtr & v)85 void IRVisitor::visit(const BitCastPtr& v) {
86   v->src_value()->accept(this);
87 }
visit(const VarPtr & v)88 void IRVisitor::visit(const VarPtr& v) {}
89 
visit(const RampPtr & v)90 void IRVisitor::visit(const RampPtr& v) {
91   v->base()->accept(this);
92   v->stride()->accept(this);
93 }
94 
visit(const LoadPtr & v)95 void IRVisitor::visit(const LoadPtr& v) {
96   v->buf()->accept(this);
97   for (const ExprPtr& ind : v->indices()) {
98     ind->accept(this);
99   }
100 }
101 
visit(const BufPtr & v)102 void IRVisitor::visit(const BufPtr& v) {
103   v->base_handle()->accept(this);
104   if (v->qscale()) {
105     v->qscale()->accept(this);
106   }
107   if (v->qzero()) {
108     v->qzero()->accept(this);
109   }
110 }
111 
visit(const StorePtr & v)112 void IRVisitor::visit(const StorePtr& v) {
113   v->buf()->accept(this);
114   for (const ExprPtr& ind : v->indices()) {
115     ind->accept(this);
116   }
117   v->value()->accept(this);
118 }
119 
visit(const AtomicAddPtr & v)120 void IRVisitor::visit(const AtomicAddPtr& v) {
121   v->buf()->accept(this);
122   for (const ExprPtr& ind : v->indices()) {
123     ind->accept(this);
124   }
125   v->value()->accept(this);
126 }
127 
visit(const SyncThreadsPtr & v)128 void IRVisitor::visit(const SyncThreadsPtr& v) {}
129 
visit(const ExternalCallPtr & v)130 void IRVisitor::visit(const ExternalCallPtr& v) {
131   v->buf()->accept(this);
132   for (const BufPtr& buf_arg : v->buf_args()) {
133     buf_arg->accept(this);
134   }
135   for (const ExprPtr& arg : v->args()) {
136     arg->accept(this);
137   }
138 }
139 
visit(const ExternalCallWithAllocPtr & v)140 void IRVisitor::visit(const ExternalCallWithAllocPtr& v) {
141   for (const auto& buf_out_arg : v->buf_out_args()) {
142     buf_out_arg->accept(this);
143   }
144   for (const auto& buf_arg : v->buf_args()) {
145     buf_arg->accept(this);
146   }
147   for (const auto& arg : v->args()) {
148     arg->accept(this);
149   }
150 }
151 
visit(const FreeExtPtr & v)152 void IRVisitor::visit(const FreeExtPtr& v) {
153   for (const auto& buf : v->bufs()) {
154     buf->accept(this);
155   }
156 }
157 
visit(const BlockPtr & v)158 void IRVisitor::visit(const BlockPtr& v) {
159   for (const StmtPtr& s : *v) {
160     s->accept(this);
161   }
162 }
163 
visit(const ForPtr & v)164 void IRVisitor::visit(const ForPtr& v) {
165   v->var()->accept(this);
166   v->start()->accept(this);
167   v->stop()->accept(this);
168   if (v->body()) {
169     v->body()->accept(this);
170   }
171 }
172 
visit(const BroadcastPtr & v)173 void IRVisitor::visit(const BroadcastPtr& v) {
174   v->value()->accept(this);
175 }
176 
visit(const IfThenElsePtr & v)177 void IRVisitor::visit(const IfThenElsePtr& v) {
178   v->condition()->accept(this);
179   v->true_value()->accept(this);
180   v->false_value()->accept(this);
181 }
182 
visit(const IntrinsicsPtr & v)183 void IRVisitor::visit(const IntrinsicsPtr& v) {
184   for (const auto i : c10::irange(v->nparams())) {
185     v->param(i)->accept(this);
186   }
187 }
188 
visit(const AllocatePtr & v)189 void IRVisitor::visit(const AllocatePtr& v) {
190   v->buffer_var()->accept(this);
191   std::vector<ExprPtr> dims = v->dims();
192   for (const ExprPtr& dim : dims) {
193     dim->accept(this);
194   }
195 }
196 
visit(const FreePtr & v)197 void IRVisitor::visit(const FreePtr& v) {
198   v->buffer_var()->accept(this);
199 }
200 
visit(const PlacementAllocatePtr & v)201 void IRVisitor::visit(const PlacementAllocatePtr& v) {
202   v->buf()->accept(this);
203   v->buf_to_reuse()->accept(this);
204 }
205 
visit(const LetPtr & v)206 void IRVisitor::visit(const LetPtr& v) {
207   v->var()->accept(this);
208   v->value()->accept(this);
209 }
210 
visit(const CondPtr & v)211 void IRVisitor::visit(const CondPtr& v) {
212   ExprPtr condition = v->condition();
213   StmtPtr true_stmt = v->true_stmt();
214   StmtPtr false_stmt = v->false_stmt();
215   condition->accept(this);
216   if (true_stmt) {
217     true_stmt->accept(this);
218   }
219   if (false_stmt) {
220     false_stmt->accept(this);
221   }
222 }
223 
visit(const TermPtr & v)224 void IRVisitor::visit(const TermPtr& v) {
225   v->scalar()->accept(this);
226   for (const auto& t : v->variables()) {
227     t->accept(this);
228   }
229 }
230 
visit(const PolynomialPtr & v)231 void IRVisitor::visit(const PolynomialPtr& v) {
232   v->scalar()->accept(this);
233   for (const auto& t : v->variables()) {
234     t->accept(this);
235   }
236 }
237 
visit(const RoundOffPtr & v)238 void IRVisitor::visit(const RoundOffPtr& v) {
239   v->lhs()->accept(this);
240   v->rhs()->accept(this);
241 }
242 
visit(const MaxTermPtr & v)243 void IRVisitor::visit(const MaxTermPtr& v) {
244   if (v->scalar()) {
245     v->scalar()->accept(this);
246   }
247   for (const auto& t : v->variables()) {
248     t->accept(this);
249   }
250 }
251 
visit(const MinTermPtr & v)252 void IRVisitor::visit(const MinTermPtr& v) {
253   if (v->scalar()) {
254     v->scalar()->accept(this);
255   }
256   for (const auto& t : v->variables()) {
257     t->accept(this);
258   }
259 }
260 
visit(const ReduceOpPtr & v)261 void IRVisitor::visit(const ReduceOpPtr& v) {
262   v->body()->accept(this);
263 
264   for (const auto& r : v->reduce_args()) {
265     r->accept(this);
266   }
267 }
268 
269 } // namespace torch::jit::tensorexpr
270