1 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
2
3 #include <torch/csrc/jit/tensorexpr/eval.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
6 #include <torch/csrc/jit/tensorexpr/reduction.h>
7
8 #include <c10/util/irange.h>
9
10 namespace torch::jit::tensorexpr {
11
12 template <
13 typename Op,
14 std::enable_if_t<std::is_same_v<
15 decltype(detail::bin_op_deducer(std::declval<Op>())),
16 void>>* = nullptr>
mutate_binary_op(NodePtr<Op> v,IRMutator * mutator,bool option=false)17 static ExprPtr mutate_binary_op(
18 NodePtr<Op> v,
19 IRMutator* mutator,
20 bool option = false) {
21 ExprPtr lhs = v->lhs();
22 ExprPtr rhs = v->rhs();
23 ExprPtr lhs_new = lhs->accept_mutator(mutator);
24 ExprPtr rhs_new = rhs->accept_mutator(mutator);
25 if (lhs != lhs_new) {
26 v->set_lhs(lhs_new);
27 }
28 if (rhs != rhs_new) {
29 v->set_rhs(rhs_new);
30 }
31 Dtype dtype_new =
32 BinaryOpDtype(lhs_new->dtype(), rhs_new->dtype(), ScalarType::Undefined);
33 if (dtype_new != v->dtype()) {
34 v->set_dtype(dtype_new);
35 }
36 return v;
37 }
38
mutate(const AddPtr & v)39 ExprPtr IRMutator::mutate(const AddPtr& v) {
40 return mutate_binary_op(v, this);
41 }
42
mutate(const SubPtr & v)43 ExprPtr IRMutator::mutate(const SubPtr& v) {
44 return mutate_binary_op(v, this);
45 }
46
mutate(const MulPtr & v)47 ExprPtr IRMutator::mutate(const MulPtr& v) {
48 return mutate_binary_op(v, this);
49 }
50
mutate(const DivPtr & v)51 ExprPtr IRMutator::mutate(const DivPtr& v) {
52 return mutate_binary_op(v, this);
53 }
54
mutate(const ModPtr & v)55 ExprPtr IRMutator::mutate(const ModPtr& v) {
56 return mutate_binary_op(v, this);
57 }
58
mutate(const AndPtr & v)59 ExprPtr IRMutator::mutate(const AndPtr& v) {
60 return mutate_binary_op(v, this);
61 }
62
mutate(const OrPtr & v)63 ExprPtr IRMutator::mutate(const OrPtr& v) {
64 return mutate_binary_op(v, this);
65 }
66
mutate(const XorPtr & v)67 ExprPtr IRMutator::mutate(const XorPtr& v) {
68 return mutate_binary_op(v, this);
69 }
70
mutate(const LshiftPtr & v)71 ExprPtr IRMutator::mutate(const LshiftPtr& v) {
72 return mutate_binary_op(v, this);
73 }
74
mutate(const RshiftPtr & v)75 ExprPtr IRMutator::mutate(const RshiftPtr& v) {
76 return mutate_binary_op(v, this);
77 }
78
mutate(const MaxPtr & v)79 ExprPtr IRMutator::mutate(const MaxPtr& v) {
80 return mutate_binary_op(v, this, v->propagate_nans());
81 }
82
mutate(const MinPtr & v)83 ExprPtr IRMutator::mutate(const MinPtr& v) {
84 return mutate_binary_op(v, this, v->propagate_nans());
85 }
86
mutate(const CompareSelectPtr & v)87 ExprPtr IRMutator::mutate(const CompareSelectPtr& v) {
88 ExprPtr lhs = v->lhs();
89 ExprPtr rhs = v->rhs();
90 ExprPtr ret_val1 = v->ret_val1();
91 ExprPtr ret_val2 = v->ret_val2();
92 ExprPtr lhs_new = lhs->accept_mutator(this);
93 ExprPtr rhs_new = rhs->accept_mutator(this);
94 ExprPtr ret_val1_new = ret_val1->accept_mutator(this);
95 ExprPtr ret_val2_new = ret_val2->accept_mutator(this);
96 if (lhs != lhs_new) {
97 v->set_lhs(lhs_new);
98 }
99 if (rhs != rhs_new) {
100 v->set_rhs(rhs_new);
101 }
102 if (ret_val1 != ret_val1_new) {
103 v->set_ret_val1(ret_val1_new);
104 }
105 if (ret_val2 != ret_val2_new) {
106 v->set_ret_val2(ret_val2_new);
107 }
108 return v;
109 }
110
111 // NOLINTNEXTLINE
112 #define IMM_MUTATE_DEFINE(_1, Name) \
113 ExprPtr IRMutator::mutate(const Name##ImmPtr& v) { \
114 return v; \
115 }
116 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
117 #undef IMM_MUTATE_DEFINE
118
mutate(const CastPtr & v)119 ExprPtr IRMutator::mutate(const CastPtr& v) {
120 ExprPtr src_value = v->src_value();
121 ExprPtr src_value_new = src_value->accept_mutator(this);
122 if (src_value != src_value_new) {
123 v->set_src_value(src_value_new);
124 }
125 return v;
126 }
127
mutate(const BitCastPtr & v)128 ExprPtr IRMutator::mutate(const BitCastPtr& v) {
129 ExprPtr src_value = v->src_value();
130 ExprPtr src_value_new = src_value->accept_mutator(this);
131 if (src_value != src_value_new) {
132 v->set_src_value(src_value_new);
133 }
134 return v;
135 }
136
mutate(const VarPtr & v)137 ExprPtr IRMutator::mutate(const VarPtr& v) {
138 return v;
139 }
140
mutate(const RampPtr & v)141 ExprPtr IRMutator::mutate(const RampPtr& v) {
142 ExprPtr base = v->base();
143 ExprPtr stride = v->stride();
144 ExprPtr base_new = base->accept_mutator(this);
145 ExprPtr stride_new = stride->accept_mutator(this);
146 if (base != base_new) {
147 v->set_base(base_new);
148 }
149 if (stride != stride_new) {
150 v->set_stride(stride_new);
151 }
152 return v;
153 }
154
mutate(const LoadPtr & v)155 ExprPtr IRMutator::mutate(const LoadPtr& v) {
156 BufPtr buf = v->buf();
157
158 bool any_index_changed = false;
159 std::vector<ExprPtr> indices_new;
160 indices_new.reserve(v->indices().size());
161 for (const ExprPtr& ind : v->indices()) {
162 ExprPtr new_ind = ind->accept_mutator(this);
163 if (new_ind != ind) {
164 any_index_changed = true;
165 }
166 indices_new.push_back(new_ind);
167 }
168 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
169
170 if (buf != buf_new) {
171 v->set_buf(buf_new);
172 }
173 if (any_index_changed) {
174 v->set_indices(indices_new);
175 }
176 return v;
177 }
178
mutate(const BufPtr & v)179 ExprPtr IRMutator::mutate(const BufPtr& v) {
180 const VarPtr& var = v->base_handle();
181 const VarPtr& var_new = to<Var>(var->accept_mutator(this));
182 if (!var_new) {
183 return nullptr;
184 }
185
186 bool dims_changed = false;
187 std::vector<ExprPtr> dims_old = v->dims();
188 std::vector<ExprPtr> dims_new(dims_old.size());
189 for (const auto i : c10::irange(dims_old.size())) {
190 dims_new[i] = dims_old[i]->accept_mutator(this);
191 dims_changed |= (dims_new[i] != dims_old[i]);
192 }
193
194 if (var != var_new) {
195 v->set_base_handle(var_new);
196 }
197 if (dims_changed) {
198 v->set_dims(dims_new);
199 }
200
201 ExprPtr qscale = v->qscale();
202 if (qscale) {
203 ExprPtr qscale_new = qscale->accept_mutator(this);
204 if (qscale != qscale_new) {
205 v->set_qscale(qscale_new);
206 }
207 }
208
209 ExprPtr qzero = v->qzero();
210 if (qzero) {
211 ExprPtr qzero_new = qzero->accept_mutator(this);
212 if (qzero != qzero_new) {
213 v->set_qzero(qzero_new);
214 }
215 }
216
217 return v;
218 }
219
mutate(const BroadcastPtr & v)220 ExprPtr IRMutator::mutate(const BroadcastPtr& v) {
221 const ExprPtr& value = v->value();
222 const ExprPtr& value_new = value->accept_mutator(this);
223 if (value != value_new) {
224 v->set_value(value_new);
225 }
226 return v;
227 }
228
mutate(const IfThenElsePtr & v)229 ExprPtr IRMutator::mutate(const IfThenElsePtr& v) {
230 ExprPtr condition = v->condition();
231 ExprPtr true_value = v->true_value();
232 ExprPtr false_value = v->false_value();
233 ExprPtr condition_new = condition->accept_mutator(this);
234 ExprPtr true_value_new = true_value->accept_mutator(this);
235 ExprPtr false_value_new = false_value->accept_mutator(this);
236
237 if (condition != condition_new) {
238 v->set_condition(condition_new);
239 }
240 if (true_value != true_value_new) {
241 v->set_true_value(true_value_new);
242 }
243 if (false_value != false_value_new) {
244 v->set_false_value(false_value_new);
245 }
246 return v;
247 }
248
mutate(const IntrinsicsPtr & v)249 ExprPtr IRMutator::mutate(const IntrinsicsPtr& v) {
250 std::vector<ExprPtr> params(v->nparams());
251 bool any_change = false;
252 for (size_t i = 0; i < v->nparams(); i++) {
253 const ExprPtr& value = v->param(i);
254 const ExprPtr& value_new = value->accept_mutator(this);
255 if (value != value_new) {
256 any_change = true;
257 }
258 params[i] = value_new;
259 }
260 if (any_change) {
261 v->set_params(params);
262 }
263 return v;
264 }
265
mutate(const TermPtr & v)266 ExprPtr IRMutator::mutate(const TermPtr& v) {
267 ExprPtr newScalar = v->scalar()->accept_mutator(this);
268
269 std::vector<ExprPtr> variables;
270 for (const auto& t : v->variables()) {
271 variables.push_back(t->accept_mutator(this));
272 }
273 return alloc<Term>(v->hasher(), newScalar, variables);
274 }
275
mutate(const PolynomialPtr & v)276 ExprPtr IRMutator::mutate(const PolynomialPtr& v) {
277 ExprPtr newScalar = v->scalar()->accept_mutator(this);
278
279 std::vector<TermPtr> variables;
280 for (const auto& t : v->variables()) {
281 variables.push_back(static_to<Term>(t->accept_mutator(this)));
282 }
283 return alloc<Polynomial>(v->hasher(), newScalar, variables);
284 }
285
mutate(const RoundOffPtr & v)286 ExprPtr IRMutator::mutate(const RoundOffPtr& v) {
287 return alloc<RoundOff>(
288 v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
289 }
290
mutate(const MaxTermPtr & v)291 ExprPtr IRMutator::mutate(const MaxTermPtr& v) {
292 ExprPtr newScalar = nullptr;
293 if (v->scalar()) {
294 newScalar = v->scalar()->accept_mutator(this);
295 }
296
297 std::vector<ExprPtr> variables;
298 for (const auto& t : v->variables()) {
299 variables.push_back(t->accept_mutator(this));
300 }
301 return alloc<MaxTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
302 }
303
mutate(const MinTermPtr & v)304 ExprPtr IRMutator::mutate(const MinTermPtr& v) {
305 ExprPtr newScalar = nullptr;
306 if (v->scalar()) {
307 newScalar = v->scalar()->accept_mutator(this);
308 }
309
310 std::vector<ExprPtr> variables;
311 for (const auto& t : v->variables()) {
312 variables.push_back(t->accept_mutator(this));
313 }
314 return alloc<MinTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
315 }
316
mutate(const ReduceOpPtr & v)317 ExprPtr IRMutator::mutate(const ReduceOpPtr& v) {
318 ExprPtr body_new = v->body()->accept_mutator(this);
319
320 std::vector<VarPtr> new_reduce_args;
321 for (const auto& r : v->reduce_args()) {
322 new_reduce_args.push_back(static_to<Var>(r->accept_mutator(this)));
323 }
324
325 return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
326 }
327
mutate(const ForPtr & v)328 StmtPtr IRMutator::mutate(const ForPtr& v) {
329 const ExprPtr& var = v->var();
330 ExprPtr start = v->start();
331 ExprPtr stop = v->stop();
332 StmtPtr body = v->body();
333 LoopOptions loop_options = v->loop_options();
334 const ExprPtr& var_new_expr = var->accept_mutator(this);
335 const VarPtr& var_new = to<Var>(var_new_expr);
336 ExprPtr start_new = start->accept_mutator(this);
337 ExprPtr stop_new = stop->accept_mutator(this);
338 StmtPtr body_new = body->accept_mutator(this);
339 if (!body_new) {
340 return nullptr;
341 }
342 if (body != body_new) {
343 v->set_body(body_new);
344 }
345 if (var != var_new) {
346 v->set_var(var_new);
347 }
348 if (start != start_new) {
349 v->set_start(start_new);
350 }
351 if (stop != stop_new) {
352 v->set_stop(stop_new);
353 }
354 return v;
355 }
356
mutate(const BlockPtr & v)357 StmtPtr IRMutator::mutate(const BlockPtr& v) {
358 bool any_change = false;
359
360 std::vector<StmtPtr> stmts;
361 for (const StmtPtr& stmt : *v) {
362 StmtPtr stmt_new = stmt->accept_mutator(this);
363 if (stmt != stmt_new) {
364 any_change = true;
365 } else {
366 stmt_new = Stmt::clone(stmt);
367 }
368 if (stmt_new) {
369 stmts.push_back(stmt_new);
370 }
371 }
372 if (any_change) {
373 v->set_stmts(stmts);
374 }
375 return v;
376 }
377
mutate(const StorePtr & v)378 StmtPtr IRMutator::mutate(const StorePtr& v) {
379 BufPtr buf = v->buf();
380
381 bool any_index_changed = false;
382 std::vector<ExprPtr> indices_new;
383 for (const ExprPtr& ind : v->indices()) {
384 ExprPtr new_ind = ind->accept_mutator(this);
385 if (new_ind != ind) {
386 any_index_changed = true;
387 }
388 indices_new.push_back(new_ind);
389 }
390 const ExprPtr& value = v->value();
391 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
392 const ExprPtr& value_new = value->accept_mutator(this);
393
394 if (buf != buf_new) {
395 v->set_buf(buf_new);
396 }
397 if (any_index_changed) {
398 v->set_indices(indices_new);
399 }
400 if (value != value_new) {
401 v->set_value(value_new);
402 }
403 return v;
404 }
405
mutate(const AtomicAddPtr & v)406 StmtPtr IRMutator::mutate(const AtomicAddPtr& v) {
407 BufPtr buf = v->buf();
408
409 bool any_index_changed = false;
410 std::vector<ExprPtr> indices_new;
411 for (const ExprPtr& ind : v->indices()) {
412 ExprPtr new_ind = ind->accept_mutator(this);
413 if (new_ind != ind) {
414 any_index_changed = true;
415 }
416 indices_new.push_back(new_ind);
417 }
418 const ExprPtr& value = v->value();
419 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
420 const ExprPtr& value_new = value->accept_mutator(this);
421
422 if (buf != buf_new) {
423 v->set_buf(buf_new);
424 }
425 if (any_index_changed) {
426 v->set_indices(indices_new);
427 }
428 if (value != value_new) {
429 v->set_value(value_new);
430 }
431 return v;
432 }
433
mutate(const SyncThreadsPtr & v)434 StmtPtr IRMutator::mutate(const SyncThreadsPtr& v) {
435 return alloc<SyncThreads>();
436 }
437
mutate(const ExternalCallPtr & v)438 StmtPtr IRMutator::mutate(const ExternalCallPtr& v) {
439 BufPtr buf = v->buf();
440 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
441 TORCH_INTERNAL_ASSERT(
442 buf_new, buildErrorMessage("IRMutator produced null for Buf."));
443
444 bool buf_args_changed = false;
445 std::vector<BufPtr> buf_args_new;
446 buf_args_new.reserve(v->buf_args().size());
447 for (const BufPtr& buf_arg : v->buf_args()) {
448 BufPtr buf_arg_new = to<Buf>(buf_arg->accept_mutator(this));
449 TORCH_INTERNAL_ASSERT(
450 buf_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
451 buf_args_new.push_back(buf_arg_new);
452 buf_args_changed |= buf_arg_new != buf_arg;
453 }
454
455 bool args_changed = false;
456 std::vector<ExprPtr> args_new;
457 args_new.reserve(v->args().size());
458 for (const ExprPtr& arg : v->args()) {
459 ExprPtr arg_new = arg->accept_mutator(this);
460 args_new.push_back(arg_new);
461 args_changed |= arg_new != arg;
462 }
463
464 if (buf != buf_new) {
465 v->set_buf(buf_new);
466 }
467 if (buf_args_changed) {
468 v->set_buf_args(buf_args_new);
469 }
470 if (args_changed) {
471 v->set_args(args_new);
472 }
473 return v;
474 }
475
mutate(const ExternalCallWithAllocPtr & v)476 StmtPtr IRMutator::mutate(const ExternalCallWithAllocPtr& v) {
477 bool buf_out_args_changed = false;
478 std::vector<BufPtr> buf_out_args_new;
479 buf_out_args_new.reserve(v->buf_out_args().size());
480 for (const auto& buf_out_arg : v->buf_out_args()) {
481 BufPtr buf_out_arg_new = to<Buf>(buf_out_arg->accept_mutator(this));
482 TORCH_INTERNAL_ASSERT(
483 buf_out_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
484 buf_out_args_new.push_back(buf_out_arg_new);
485 buf_out_args_changed |= buf_out_arg_new != buf_out_arg;
486 }
487
488 bool buf_args_changed = false;
489 std::vector<BufPtr> buf_args_new;
490 buf_args_new.reserve(v->buf_args().size());
491 for (const auto& buf_arg : v->buf_args()) {
492 BufPtr buf_arg_new = to<Buf>(buf_arg->accept_mutator(this));
493 TORCH_INTERNAL_ASSERT(
494 buf_arg_new, buildErrorMessage("IRMutator produced null for Buf."));
495 buf_args_new.push_back(buf_arg_new);
496 buf_args_changed |= buf_arg_new != buf_arg;
497 }
498
499 bool args_changed = false;
500 std::vector<ExprPtr> args_new;
501 args_new.reserve(v->args().size());
502 for (const auto& arg : v->args()) {
503 ExprPtr arg_new = arg->accept_mutator(this);
504 args_new.push_back(arg_new);
505 args_changed |= arg_new != arg;
506 }
507
508 if (buf_out_args_changed) {
509 v->set_buf_out_args(buf_out_args_new);
510 }
511 if (buf_args_changed) {
512 v->set_buf_args(buf_args_new);
513 }
514 if (args_changed) {
515 v->set_args(args_new);
516 }
517 return v;
518 }
519
mutate(const AllocatePtr & v)520 StmtPtr IRMutator::mutate(const AllocatePtr& v) {
521 BufPtr buf = v->buf();
522 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
523 TORCH_INTERNAL_ASSERT(
524 buf_new, buildErrorMessage("IRMutator produced null for Buf."));
525 if (buf != buf_new) {
526 v->set_buf(buf_new);
527 }
528 return v;
529 }
530
mutate(const FreePtr & v)531 StmtPtr IRMutator::mutate(const FreePtr& v) {
532 BufPtr buf = v->buf();
533 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
534 TORCH_INTERNAL_ASSERT(
535 buf_new, buildErrorMessage("IRMutator produced null for Buf."));
536 if (buf != buf_new) {
537 v->set_buf(buf_new);
538 }
539 return v;
540 }
541
mutate(const FreeExtPtr & v)542 StmtPtr IRMutator::mutate(const FreeExtPtr& v) {
543 bool bufs_changed = false;
544 std::vector<BufPtr> bufs_new;
545 bufs_new.reserve(v->bufs().size());
546 for (const auto& buf : v->bufs()) {
547 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
548 TORCH_INTERNAL_ASSERT(
549 buf_new, buildErrorMessage("IRMutator produced null for Buf."));
550 bufs_new.push_back(buf_new);
551 bufs_changed |= buf_new != buf;
552 }
553
554 if (bufs_changed) {
555 v->set_bufs(bufs_new);
556 }
557 return v;
558 }
559
mutate(const PlacementAllocatePtr & v)560 StmtPtr IRMutator::mutate(const PlacementAllocatePtr& v) {
561 BufPtr buf = v->buf();
562 BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
563 TORCH_INTERNAL_ASSERT(
564 buf_new, buildErrorMessage("IRMutator produced null for Buf."));
565 v->set_buf(buf_new);
566
567 BufPtr buf_to_reuse = v->buf_to_reuse();
568 BufPtr buf_to_reuse_new = to<Buf>(buf_to_reuse->accept_mutator(this));
569 TORCH_INTERNAL_ASSERT(
570 buf_to_reuse_new, buildErrorMessage("IRMutator produced null for Buf."));
571 v->set_buf_to_reuse(buf_to_reuse_new);
572
573 return v;
574 }
575
mutate(const LetPtr & v)576 StmtPtr IRMutator::mutate(const LetPtr& v) {
577 const VarPtr& var_old = v->var();
578 const VarPtr& var_new = to<Var>(var_old->accept_mutator(this));
579
580 const ExprPtr& val_old = v->value();
581 const ExprPtr& val_new = val_old->accept_mutator(this);
582
583 if (var_old != var_new) {
584 v->set_var(var_new);
585 }
586 if (val_old != val_new) {
587 v->set_val(val_new);
588 }
589 return v;
590 }
591
mutate(const CondPtr & v)592 StmtPtr IRMutator::mutate(const CondPtr& v) {
593 ExprPtr cond_old = v->condition();
594 StmtPtr true_old = v->true_stmt();
595 StmtPtr false_old = v->false_stmt();
596
597 ExprPtr cond_new = cond_old->accept_mutator(this);
598 StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
599 StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
600
601 if (cond_old != cond_new) {
602 v->set_condition(cond_new);
603 }
604
605 if (true_old != true_new) {
606 v->set_true_stmt(true_new);
607 }
608
609 if (false_old != false_new) {
610 v->set_false_stmt(false_new);
611 }
612
613 return v;
614 }
615
616 } // namespace torch::jit::tensorexpr
617