1 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7
8 #include <c10/util/irange.h>
9
10 namespace torch::jit::tensorexpr {
11
12 template <
13 typename Op,
14 std::enable_if_t<std::is_same_v<
15 decltype(detail::bin_op_deducer(std::declval<Op>())),
16 void>>* = nullptr>
visit_binary_op(const NodePtr<Op> & v,IRVisitor * visitor)17 static void visit_binary_op(const NodePtr<Op>& v, IRVisitor* visitor) {
18 v->lhs()->accept(visitor);
19 v->rhs()->accept(visitor);
20 }
21
visit(const AddPtr & v)22 void IRVisitor::visit(const AddPtr& v) {
23 visit_binary_op(v, this);
24 }
25
visit(const SubPtr & v)26 void IRVisitor::visit(const SubPtr& v) {
27 visit_binary_op(v, this);
28 }
29
visit(const MulPtr & v)30 void IRVisitor::visit(const MulPtr& v) {
31 visit_binary_op(v, this);
32 }
33
visit(const DivPtr & v)34 void IRVisitor::visit(const DivPtr& v) {
35 visit_binary_op(v, this);
36 }
37
visit(const ModPtr & v)38 void IRVisitor::visit(const ModPtr& v) {
39 visit_binary_op(v, this);
40 }
41
visit(const MaxPtr & v)42 void IRVisitor::visit(const MaxPtr& v) {
43 visit_binary_op(v, this);
44 }
45
visit(const MinPtr & v)46 void IRVisitor::visit(const MinPtr& v) {
47 visit_binary_op(v, this);
48 }
49
visit(const AndPtr & v)50 void IRVisitor::visit(const AndPtr& v) {
51 visit_binary_op(v, this);
52 }
53
visit(const OrPtr & v)54 void IRVisitor::visit(const OrPtr& v) {
55 visit_binary_op(v, this);
56 }
57
visit(const XorPtr & v)58 void IRVisitor::visit(const XorPtr& v) {
59 visit_binary_op(v, this);
60 }
61
visit(const LshiftPtr & v)62 void IRVisitor::visit(const LshiftPtr& v) {
63 visit_binary_op(v, this);
64 }
65
visit(const RshiftPtr & v)66 void IRVisitor::visit(const RshiftPtr& v) {
67 visit_binary_op(v, this);
68 }
69
visit(const CompareSelectPtr & v)70 void IRVisitor::visit(const CompareSelectPtr& v) {
71 v->lhs()->accept(this);
72 v->rhs()->accept(this);
73 v->ret_val1()->accept(this);
74 v->ret_val2()->accept(this);
75 }
76
77 #define IMM_VISIT(Type, Name) \
78 void IRVisitor::visit(const Name##ImmPtr& v) {}
79 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
80 #undef IMM_VISIT
81
visit(const CastPtr & v)82 void IRVisitor::visit(const CastPtr& v) {
83 v->src_value()->accept(this);
84 }
visit(const BitCastPtr & v)85 void IRVisitor::visit(const BitCastPtr& v) {
86 v->src_value()->accept(this);
87 }
visit(const VarPtr & v)88 void IRVisitor::visit(const VarPtr& v) {}
89
visit(const RampPtr & v)90 void IRVisitor::visit(const RampPtr& v) {
91 v->base()->accept(this);
92 v->stride()->accept(this);
93 }
94
visit(const LoadPtr & v)95 void IRVisitor::visit(const LoadPtr& v) {
96 v->buf()->accept(this);
97 for (const ExprPtr& ind : v->indices()) {
98 ind->accept(this);
99 }
100 }
101
visit(const BufPtr & v)102 void IRVisitor::visit(const BufPtr& v) {
103 v->base_handle()->accept(this);
104 if (v->qscale()) {
105 v->qscale()->accept(this);
106 }
107 if (v->qzero()) {
108 v->qzero()->accept(this);
109 }
110 }
111
visit(const StorePtr & v)112 void IRVisitor::visit(const StorePtr& v) {
113 v->buf()->accept(this);
114 for (const ExprPtr& ind : v->indices()) {
115 ind->accept(this);
116 }
117 v->value()->accept(this);
118 }
119
visit(const AtomicAddPtr & v)120 void IRVisitor::visit(const AtomicAddPtr& v) {
121 v->buf()->accept(this);
122 for (const ExprPtr& ind : v->indices()) {
123 ind->accept(this);
124 }
125 v->value()->accept(this);
126 }
127
visit(const SyncThreadsPtr & v)128 void IRVisitor::visit(const SyncThreadsPtr& v) {}
129
visit(const ExternalCallPtr & v)130 void IRVisitor::visit(const ExternalCallPtr& v) {
131 v->buf()->accept(this);
132 for (const BufPtr& buf_arg : v->buf_args()) {
133 buf_arg->accept(this);
134 }
135 for (const ExprPtr& arg : v->args()) {
136 arg->accept(this);
137 }
138 }
139
visit(const ExternalCallWithAllocPtr & v)140 void IRVisitor::visit(const ExternalCallWithAllocPtr& v) {
141 for (const auto& buf_out_arg : v->buf_out_args()) {
142 buf_out_arg->accept(this);
143 }
144 for (const auto& buf_arg : v->buf_args()) {
145 buf_arg->accept(this);
146 }
147 for (const auto& arg : v->args()) {
148 arg->accept(this);
149 }
150 }
151
visit(const FreeExtPtr & v)152 void IRVisitor::visit(const FreeExtPtr& v) {
153 for (const auto& buf : v->bufs()) {
154 buf->accept(this);
155 }
156 }
157
visit(const BlockPtr & v)158 void IRVisitor::visit(const BlockPtr& v) {
159 for (const StmtPtr& s : *v) {
160 s->accept(this);
161 }
162 }
163
visit(const ForPtr & v)164 void IRVisitor::visit(const ForPtr& v) {
165 v->var()->accept(this);
166 v->start()->accept(this);
167 v->stop()->accept(this);
168 if (v->body()) {
169 v->body()->accept(this);
170 }
171 }
172
visit(const BroadcastPtr & v)173 void IRVisitor::visit(const BroadcastPtr& v) {
174 v->value()->accept(this);
175 }
176
visit(const IfThenElsePtr & v)177 void IRVisitor::visit(const IfThenElsePtr& v) {
178 v->condition()->accept(this);
179 v->true_value()->accept(this);
180 v->false_value()->accept(this);
181 }
182
visit(const IntrinsicsPtr & v)183 void IRVisitor::visit(const IntrinsicsPtr& v) {
184 for (const auto i : c10::irange(v->nparams())) {
185 v->param(i)->accept(this);
186 }
187 }
188
visit(const AllocatePtr & v)189 void IRVisitor::visit(const AllocatePtr& v) {
190 v->buffer_var()->accept(this);
191 std::vector<ExprPtr> dims = v->dims();
192 for (const ExprPtr& dim : dims) {
193 dim->accept(this);
194 }
195 }
196
visit(const FreePtr & v)197 void IRVisitor::visit(const FreePtr& v) {
198 v->buffer_var()->accept(this);
199 }
200
visit(const PlacementAllocatePtr & v)201 void IRVisitor::visit(const PlacementAllocatePtr& v) {
202 v->buf()->accept(this);
203 v->buf_to_reuse()->accept(this);
204 }
205
visit(const LetPtr & v)206 void IRVisitor::visit(const LetPtr& v) {
207 v->var()->accept(this);
208 v->value()->accept(this);
209 }
210
visit(const CondPtr & v)211 void IRVisitor::visit(const CondPtr& v) {
212 ExprPtr condition = v->condition();
213 StmtPtr true_stmt = v->true_stmt();
214 StmtPtr false_stmt = v->false_stmt();
215 condition->accept(this);
216 if (true_stmt) {
217 true_stmt->accept(this);
218 }
219 if (false_stmt) {
220 false_stmt->accept(this);
221 }
222 }
223
visit(const TermPtr & v)224 void IRVisitor::visit(const TermPtr& v) {
225 v->scalar()->accept(this);
226 for (const auto& t : v->variables()) {
227 t->accept(this);
228 }
229 }
230
visit(const PolynomialPtr & v)231 void IRVisitor::visit(const PolynomialPtr& v) {
232 v->scalar()->accept(this);
233 for (const auto& t : v->variables()) {
234 t->accept(this);
235 }
236 }
237
visit(const RoundOffPtr & v)238 void IRVisitor::visit(const RoundOffPtr& v) {
239 v->lhs()->accept(this);
240 v->rhs()->accept(this);
241 }
242
visit(const MaxTermPtr & v)243 void IRVisitor::visit(const MaxTermPtr& v) {
244 if (v->scalar()) {
245 v->scalar()->accept(this);
246 }
247 for (const auto& t : v->variables()) {
248 t->accept(this);
249 }
250 }
251
visit(const MinTermPtr & v)252 void IRVisitor::visit(const MinTermPtr& v) {
253 if (v->scalar()) {
254 v->scalar()->accept(this);
255 }
256 for (const auto& t : v->variables()) {
257 t->accept(this);
258 }
259 }
260
visit(const ReduceOpPtr & v)261 void IRVisitor::visit(const ReduceOpPtr& v) {
262 v->body()->accept(this);
263
264 for (const auto& r : v->reduce_args()) {
265 r->accept(this);
266 }
267 }
268
269 } // namespace torch::jit::tensorexpr
270