xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_mutator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
2 
3 #include <torch/csrc/jit/tensorexpr/eval.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
6 #include <torch/csrc/jit/tensorexpr/reduction.h>
7 
8 #include <c10/util/irange.h>
9 
10 namespace torch::jit::tensorexpr {
11 
12 template <
13     typename Op,
14     std::enable_if_t<std::is_same_v<
15         decltype(detail::bin_op_deducer(std::declval<Op>())),
16         void>>* = nullptr>
mutate_binary_op(NodePtr<Op> v,IRMutator * mutator,bool option=false)17 static ExprPtr mutate_binary_op(
18     NodePtr<Op> v,
19     IRMutator* mutator,
20     bool option = false) {
21   ExprPtr lhs = v->lhs();
22   ExprPtr rhs = v->rhs();
23   ExprPtr lhs_new = lhs->accept_mutator(mutator);
24   ExprPtr rhs_new = rhs->accept_mutator(mutator);
25   if (lhs != lhs_new) {
26     v->set_lhs(lhs_new);
27   }
28   if (rhs != rhs_new) {
29     v->set_rhs(rhs_new);
30   }
31   Dtype dtype_new =
32       BinaryOpDtype(lhs_new->dtype(), rhs_new->dtype(), ScalarType::Undefined);
33   if (dtype_new != v->dtype()) {
34     v->set_dtype(dtype_new);
35   }
36   return v;
37 }
38 
mutate(const AddPtr & v)39 ExprPtr IRMutator::mutate(const AddPtr& v) {
40   return mutate_binary_op(v, this);
41 }
42 
mutate(const SubPtr & v)43 ExprPtr IRMutator::mutate(const SubPtr& v) {
44   return mutate_binary_op(v, this);
45 }
46 
mutate(const MulPtr & v)47 ExprPtr IRMutator::mutate(const MulPtr& v) {
48   return mutate_binary_op(v, this);
49 }
50 
mutate(const DivPtr & v)51 ExprPtr IRMutator::mutate(const DivPtr& v) {
52   return mutate_binary_op(v, this);
53 }
54 
mutate(const ModPtr & v)55 ExprPtr IRMutator::mutate(const ModPtr& v) {
56   return mutate_binary_op(v, this);
57 }
58 
mutate(const AndPtr & v)59 ExprPtr IRMutator::mutate(const AndPtr& v) {
60   return mutate_binary_op(v, this);
61 }
62 
mutate(const OrPtr & v)63 ExprPtr IRMutator::mutate(const OrPtr& v) {
64   return mutate_binary_op(v, this);
65 }
66 
mutate(const XorPtr & v)67 ExprPtr IRMutator::mutate(const XorPtr& v) {
68   return mutate_binary_op(v, this);
69 }
70 
mutate(const LshiftPtr & v)71 ExprPtr IRMutator::mutate(const LshiftPtr& v) {
72   return mutate_binary_op(v, this);
73 }
74 
mutate(const RshiftPtr & v)75 ExprPtr IRMutator::mutate(const RshiftPtr& v) {
76   return mutate_binary_op(v, this);
77 }
78 
mutate(const MaxPtr & v)79 ExprPtr IRMutator::mutate(const MaxPtr& v) {
80   return mutate_binary_op(v, this, v->propagate_nans());
81 }
82 
mutate(const MinPtr & v)83 ExprPtr IRMutator::mutate(const MinPtr& v) {
84   return mutate_binary_op(v, this, v->propagate_nans());
85 }
86 
mutate(const CompareSelectPtr & v)87 ExprPtr IRMutator::mutate(const CompareSelectPtr& v) {
88   ExprPtr lhs = v->lhs();
89   ExprPtr rhs = v->rhs();
90   ExprPtr ret_val1 = v->ret_val1();
91   ExprPtr ret_val2 = v->ret_val2();
92   ExprPtr lhs_new = lhs->accept_mutator(this);
93   ExprPtr rhs_new = rhs->accept_mutator(this);
94   ExprPtr ret_val1_new = ret_val1->accept_mutator(this);
95   ExprPtr ret_val2_new = ret_val2->accept_mutator(this);
96   if (lhs != lhs_new) {
97     v->set_lhs(lhs_new);
98   }
99   if (rhs != rhs_new) {
100     v->set_rhs(rhs_new);
101   }
102   if (ret_val1 != ret_val1_new) {
103     v->set_ret_val1(ret_val1_new);
104   }
105   if (ret_val2 != ret_val2_new) {
106     v->set_ret_val2(ret_val2_new);
107   }
108   return v;
109 }
110 
111 // NOLINTNEXTLINE
112 #define IMM_MUTATE_DEFINE(_1, Name)                  \
113   ExprPtr IRMutator::mutate(const Name##ImmPtr& v) { \
114     return v;                                        \
115   }
116 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
117 #undef IMM_MUTATE_DEFINE
118 
mutate(const CastPtr & v)119 ExprPtr IRMutator::mutate(const CastPtr& v) {
120   ExprPtr src_value = v->src_value();
121   ExprPtr src_value_new = src_value->accept_mutator(this);
122   if (src_value != src_value_new) {
123     v->set_src_value(src_value_new);
124   }
125   return v;
126 }
127 
mutate(const BitCastPtr & v)128 ExprPtr IRMutator::mutate(const BitCastPtr& v) {
129   ExprPtr src_value = v->src_value();
130   ExprPtr src_value_new = src_value->accept_mutator(this);
131   if (src_value != src_value_new) {
132     v->set_src_value(src_value_new);
133   }
134   return v;
135 }
136 
mutate(const VarPtr & v)137 ExprPtr IRMutator::mutate(const VarPtr& v) {
138   return v;
139 }
140 
mutate(const RampPtr & v)141 ExprPtr IRMutator::mutate(const RampPtr& v) {
142   ExprPtr base = v->base();
143   ExprPtr stride = v->stride();
144   ExprPtr base_new = base->accept_mutator(this);
145   ExprPtr stride_new = stride->accept_mutator(this);
146   if (base != base_new) {
147     v->set_base(base_new);
148   }
149   if (stride != stride_new) {
150     v->set_stride(stride_new);
151   }
152   return v;
153 }
154 
mutate(const LoadPtr & v)155 ExprPtr IRMutator::mutate(const LoadPtr& v) {
156   BufPtr buf = v->buf();
157 
158   bool any_index_changed = false;
159   std::vector<ExprPtr> indices_new;
160   indices_new.reserve(v->indices().size());
161   for (const ExprPtr& ind : v->indices()) {
162     ExprPtr new_ind = ind->accept_mutator(this);
163     if (new_ind != ind) {
164       any_index_changed = true;
165     }
166     indices_new.push_back(new_ind);
167   }
168   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
169 
170   if (buf != buf_new) {
171     v->set_buf(buf_new);
172   }
173   if (any_index_changed) {
174     v->set_indices(indices_new);
175   }
176   return v;
177 }
178 
mutate(const BufPtr & v)179 ExprPtr IRMutator::mutate(const BufPtr& v) {
180   const VarPtr& var = v->base_handle();
181   const VarPtr& var_new = to<Var>(var->accept_mutator(this));
182   if (!var_new) {
183     return nullptr;
184   }
185 
186   bool dims_changed = false;
187   std::vector<ExprPtr> dims_old = v->dims();
188   std::vector<ExprPtr> dims_new(dims_old.size());
189   for (const auto i : c10::irange(dims_old.size())) {
190     dims_new[i] = dims_old[i]->accept_mutator(this);
191     dims_changed |= (dims_new[i] != dims_old[i]);
192   }
193 
194   if (var != var_new) {
195     v->set_base_handle(var_new);
196   }
197   if (dims_changed) {
198     v->set_dims(dims_new);
199   }
200 
201   ExprPtr qscale = v->qscale();
202   if (qscale) {
203     ExprPtr qscale_new = qscale->accept_mutator(this);
204     if (qscale != qscale_new) {
205       v->set_qscale(qscale_new);
206     }
207   }
208 
209   ExprPtr qzero = v->qzero();
210   if (qzero) {
211     ExprPtr qzero_new = qzero->accept_mutator(this);
212     if (qzero != qzero_new) {
213       v->set_qzero(qzero_new);
214     }
215   }
216 
217   return v;
218 }
219 
mutate(const BroadcastPtr & v)220 ExprPtr IRMutator::mutate(const BroadcastPtr& v) {
221   const ExprPtr& value = v->value();
222   const ExprPtr& value_new = value->accept_mutator(this);
223   if (value != value_new) {
224     v->set_value(value_new);
225   }
226   return v;
227 }
228 
mutate(const IfThenElsePtr & v)229 ExprPtr IRMutator::mutate(const IfThenElsePtr& v) {
230   ExprPtr condition = v->condition();
231   ExprPtr true_value = v->true_value();
232   ExprPtr false_value = v->false_value();
233   ExprPtr condition_new = condition->accept_mutator(this);
234   ExprPtr true_value_new = true_value->accept_mutator(this);
235   ExprPtr false_value_new = false_value->accept_mutator(this);
236 
237   if (condition != condition_new) {
238     v->set_condition(condition_new);
239   }
240   if (true_value != true_value_new) {
241     v->set_true_value(true_value_new);
242   }
243   if (false_value != false_value_new) {
244     v->set_false_value(false_value_new);
245   }
246   return v;
247 }
248 
mutate(const IntrinsicsPtr & v)249 ExprPtr IRMutator::mutate(const IntrinsicsPtr& v) {
250   std::vector<ExprPtr> params(v->nparams());
251   bool any_change = false;
252   for (size_t i = 0; i < v->nparams(); i++) {
253     const ExprPtr& value = v->param(i);
254     const ExprPtr& value_new = value->accept_mutator(this);
255     if (value != value_new) {
256       any_change = true;
257     }
258     params[i] = value_new;
259   }
260   if (any_change) {
261     v->set_params(params);
262   }
263   return v;
264 }
265 
mutate(const TermPtr & v)266 ExprPtr IRMutator::mutate(const TermPtr& v) {
267   ExprPtr newScalar = v->scalar()->accept_mutator(this);
268 
269   std::vector<ExprPtr> variables;
270   for (const auto& t : v->variables()) {
271     variables.push_back(t->accept_mutator(this));
272   }
273   return alloc<Term>(v->hasher(), newScalar, variables);
274 }
275 
mutate(const PolynomialPtr & v)276 ExprPtr IRMutator::mutate(const PolynomialPtr& v) {
277   ExprPtr newScalar = v->scalar()->accept_mutator(this);
278 
279   std::vector<TermPtr> variables;
280   for (const auto& t : v->variables()) {
281     variables.push_back(static_to<Term>(t->accept_mutator(this)));
282   }
283   return alloc<Polynomial>(v->hasher(), newScalar, variables);
284 }
285 
mutate(const RoundOffPtr & v)286 ExprPtr IRMutator::mutate(const RoundOffPtr& v) {
287   return alloc<RoundOff>(
288       v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
289 }
290 
mutate(const MaxTermPtr & v)291 ExprPtr IRMutator::mutate(const MaxTermPtr& v) {
292   ExprPtr newScalar = nullptr;
293   if (v->scalar()) {
294     newScalar = v->scalar()->accept_mutator(this);
295   }
296 
297   std::vector<ExprPtr> variables;
298   for (const auto& t : v->variables()) {
299     variables.push_back(t->accept_mutator(this));
300   }
301   return alloc<MaxTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
302 }
303 
mutate(const MinTermPtr & v)304 ExprPtr IRMutator::mutate(const MinTermPtr& v) {
305   ExprPtr newScalar = nullptr;
306   if (v->scalar()) {
307     newScalar = v->scalar()->accept_mutator(this);
308   }
309 
310   std::vector<ExprPtr> variables;
311   for (const auto& t : v->variables()) {
312     variables.push_back(t->accept_mutator(this));
313   }
314   return alloc<MinTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
315 }
316 
mutate(const ReduceOpPtr & v)317 ExprPtr IRMutator::mutate(const ReduceOpPtr& v) {
318   ExprPtr body_new = v->body()->accept_mutator(this);
319 
320   std::vector<VarPtr> new_reduce_args;
321   for (const auto& r : v->reduce_args()) {
322     new_reduce_args.push_back(static_to<Var>(r->accept_mutator(this)));
323   }
324 
325   return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
326 }
327 
mutate(const ForPtr & v)328 StmtPtr IRMutator::mutate(const ForPtr& v) {
329   const ExprPtr& var = v->var();
330   ExprPtr start = v->start();
331   ExprPtr stop = v->stop();
332   StmtPtr body = v->body();
333   LoopOptions loop_options = v->loop_options();
334   const ExprPtr& var_new_expr = var->accept_mutator(this);
335   const VarPtr& var_new = to<Var>(var_new_expr);
336   ExprPtr start_new = start->accept_mutator(this);
337   ExprPtr stop_new = stop->accept_mutator(this);
338   StmtPtr body_new = body->accept_mutator(this);
339   if (!body_new) {
340     return nullptr;
341   }
342   if (body != body_new) {
343     v->set_body(body_new);
344   }
345   if (var != var_new) {
346     v->set_var(var_new);
347   }
348   if (start != start_new) {
349     v->set_start(start_new);
350   }
351   if (stop != stop_new) {
352     v->set_stop(stop_new);
353   }
354   return v;
355 }
356 
mutate(const BlockPtr & v)357 StmtPtr IRMutator::mutate(const BlockPtr& v) {
358   bool any_change = false;
359 
360   std::vector<StmtPtr> stmts;
361   for (const StmtPtr& stmt : *v) {
362     StmtPtr stmt_new = stmt->accept_mutator(this);
363     if (stmt != stmt_new) {
364       any_change = true;
365     } else {
366       stmt_new = Stmt::clone(stmt);
367     }
368     if (stmt_new) {
369       stmts.push_back(stmt_new);
370     }
371   }
372   if (any_change) {
373     v->set_stmts(stmts);
374   }
375   return v;
376 }
377 
mutate(const StorePtr & v)378 StmtPtr IRMutator::mutate(const StorePtr& v) {
379   BufPtr buf = v->buf();
380 
381   bool any_index_changed = false;
382   std::vector<ExprPtr> indices_new;
383   for (const ExprPtr& ind : v->indices()) {
384     ExprPtr new_ind = ind->accept_mutator(this);
385     if (new_ind != ind) {
386       any_index_changed = true;
387     }
388     indices_new.push_back(new_ind);
389   }
390   const ExprPtr& value = v->value();
391   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
392   const ExprPtr& value_new = value->accept_mutator(this);
393 
394   if (buf != buf_new) {
395     v->set_buf(buf_new);
396   }
397   if (any_index_changed) {
398     v->set_indices(indices_new);
399   }
400   if (value != value_new) {
401     v->set_value(value_new);
402   }
403   return v;
404 }
405 
mutate(const AtomicAddPtr & v)406 StmtPtr IRMutator::mutate(const AtomicAddPtr& v) {
407   BufPtr buf = v->buf();
408 
409   bool any_index_changed = false;
410   std::vector<ExprPtr> indices_new;
411   for (const ExprPtr& ind : v->indices()) {
412     ExprPtr new_ind = ind->accept_mutator(this);
413     if (new_ind != ind) {
414       any_index_changed = true;
415     }
416     indices_new.push_back(new_ind);
417   }
418   const ExprPtr& value = v->value();
419   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
420   const ExprPtr& value_new = value->accept_mutator(this);
421 
422   if (buf != buf_new) {
423     v->set_buf(buf_new);
424   }
425   if (any_index_changed) {
426     v->set_indices(indices_new);
427   }
428   if (value != value_new) {
429     v->set_value(value_new);
430   }
431   return v;
432 }
433 
mutate(const SyncThreadsPtr & v)434 StmtPtr IRMutator::mutate(const SyncThreadsPtr& v) {
435   return alloc<SyncThreads>();
436 }
437 
mutate(const ExternalCallPtr & v)438 StmtPtr IRMutator::mutate(const ExternalCallPtr& v) {
439   BufPtr buf = v->buf();
440   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
441   TORCH_INTERNAL_ASSERT(
442       buf_new, buildErrorMessage("IRMutator produced null for Buf."));
443 
444   bool buf_args_changed = false;
445   std::vector<BufPtr> buf_args_new;
446   buf_args_new.reserve(v->buf_args().size());
447   for (const BufPtr& buf_arg : v->buf_args()) {
448     BufPtr buf_arg_new = to<Buf>(buf_arg->accept_mutator(this));
449     TORCH_INTERNAL_ASSERT(
450         buf_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
451     buf_args_new.push_back(buf_arg_new);
452     buf_args_changed |= buf_arg_new != buf_arg;
453   }
454 
455   bool args_changed = false;
456   std::vector<ExprPtr> args_new;
457   args_new.reserve(v->args().size());
458   for (const ExprPtr& arg : v->args()) {
459     ExprPtr arg_new = arg->accept_mutator(this);
460     args_new.push_back(arg_new);
461     args_changed |= arg_new != arg;
462   }
463 
464   if (buf != buf_new) {
465     v->set_buf(buf_new);
466   }
467   if (buf_args_changed) {
468     v->set_buf_args(buf_args_new);
469   }
470   if (args_changed) {
471     v->set_args(args_new);
472   }
473   return v;
474 }
475 
mutate(const ExternalCallWithAllocPtr & v)476 StmtPtr IRMutator::mutate(const ExternalCallWithAllocPtr& v) {
477   bool buf_out_args_changed = false;
478   std::vector<BufPtr> buf_out_args_new;
479   buf_out_args_new.reserve(v->buf_out_args().size());
480   for (const auto& buf_out_arg : v->buf_out_args()) {
481     BufPtr buf_out_arg_new = to<Buf>(buf_out_arg->accept_mutator(this));
482     TORCH_INTERNAL_ASSERT(
483         buf_out_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
484     buf_out_args_new.push_back(buf_out_arg_new);
485     buf_out_args_changed |= buf_out_arg_new != buf_out_arg;
486   }
487 
488   bool buf_args_changed = false;
489   std::vector<BufPtr> buf_args_new;
490   buf_args_new.reserve(v->buf_args().size());
491   for (const auto& buf_arg : v->buf_args()) {
492     BufPtr buf_arg_new = to<Buf>(buf_arg->accept_mutator(this));
493     TORCH_INTERNAL_ASSERT(
494         buf_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
495     buf_args_new.push_back(buf_arg_new);
496     buf_args_changed |= buf_arg_new != buf_arg;
497   }
498 
499   bool args_changed = false;
500   std::vector<ExprPtr> args_new;
501   args_new.reserve(v->args().size());
502   for (const auto& arg : v->args()) {
503     ExprPtr arg_new = arg->accept_mutator(this);
504     args_new.push_back(arg_new);
505     args_changed |= arg_new != arg;
506   }
507 
508   if (buf_out_args_changed) {
509     v->set_buf_out_args(buf_out_args_new);
510   }
511   if (buf_args_changed) {
512     v->set_buf_args(buf_args_new);
513   }
514   if (args_changed) {
515     v->set_args(args_new);
516   }
517   return v;
518 }
519 
mutate(const AllocatePtr & v)520 StmtPtr IRMutator::mutate(const AllocatePtr& v) {
521   BufPtr buf = v->buf();
522   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
523   TORCH_INTERNAL_ASSERT(
524       buf_new, buildErrorMessage("IRMutator produced null for Buf."));
525   if (buf != buf_new) {
526     v->set_buf(buf_new);
527   }
528   return v;
529 }
530 
mutate(const FreePtr & v)531 StmtPtr IRMutator::mutate(const FreePtr& v) {
532   BufPtr buf = v->buf();
533   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
534   TORCH_INTERNAL_ASSERT(
535       buf_new, buildErrorMessage("IRMutator produced null for Buf."));
536   if (buf != buf_new) {
537     v->set_buf(buf_new);
538   }
539   return v;
540 }
541 
mutate(const FreeExtPtr & v)542 StmtPtr IRMutator::mutate(const FreeExtPtr& v) {
543   bool bufs_changed = false;
544   std::vector<BufPtr> bufs_new;
545   bufs_new.reserve(v->bufs().size());
546   for (const auto& buf : v->bufs()) {
547     BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
548     TORCH_INTERNAL_ASSERT(
549         buf_new, buildErrorMessage("IRMutator produced null for Buf."));
550     bufs_new.push_back(buf_new);
551     bufs_changed |= buf_new != buf;
552   }
553 
554   if (bufs_changed) {
555     v->set_bufs(bufs_new);
556   }
557   return v;
558 }
559 
mutate(const PlacementAllocatePtr & v)560 StmtPtr IRMutator::mutate(const PlacementAllocatePtr& v) {
561   BufPtr buf = v->buf();
562   BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
563   TORCH_INTERNAL_ASSERT(
564       buf_new, buildErrorMessage("IRMutator produced null for Buf."));
565   v->set_buf(buf_new);
566 
567   BufPtr buf_to_reuse = v->buf_to_reuse();
568   BufPtr buf_to_reuse_new = to<Buf>(buf_to_reuse->accept_mutator(this));
569   TORCH_INTERNAL_ASSERT(
570       buf_to_reuse_new, buildErrorMessage("IRMutator produced null for Buf."));
571   v->set_buf_to_reuse(buf_to_reuse_new);
572 
573   return v;
574 }
575 
mutate(const LetPtr & v)576 StmtPtr IRMutator::mutate(const LetPtr& v) {
577   const VarPtr& var_old = v->var();
578   const VarPtr& var_new = to<Var>(var_old->accept_mutator(this));
579 
580   const ExprPtr& val_old = v->value();
581   const ExprPtr& val_new = val_old->accept_mutator(this);
582 
583   if (var_old != var_new) {
584     v->set_var(var_new);
585   }
586   if (val_old != val_new) {
587     v->set_val(val_new);
588   }
589   return v;
590 }
591 
mutate(const CondPtr & v)592 StmtPtr IRMutator::mutate(const CondPtr& v) {
593   ExprPtr cond_old = v->condition();
594   StmtPtr true_old = v->true_stmt();
595   StmtPtr false_old = v->false_stmt();
596 
597   ExprPtr cond_new = cond_old->accept_mutator(this);
598   StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
599   StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
600 
601   if (cond_old != cond_new) {
602     v->set_condition(cond_new);
603   }
604 
605   if (true_old != true_new) {
606     v->set_true_stmt(true_new);
607   }
608 
609   if (false_old != false_new) {
610     v->set_false_stmt(false_new);
611   }
612 
613   return v;
614 }
615 
616 } // namespace torch::jit::tensorexpr
617