xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/normalize_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/normalize_ops.h>
2 
3 #include <c10/util/Exception.h>
4 
5 namespace torch::jit {
6 
7 namespace {
8 
9 // having multiple ops in our IR that do the same thing makes the IR more
10 // difficult to consumer for downstream user of the IR, such as our own
11 // optimization passes here, we convert op aliases into a standard form
normalizeOpAliases(graph_node_list_iterator & iter)12 bool normalizeOpAliases(graph_node_list_iterator& iter) {
13   auto alias = getOperatorAliasMap().find(iter->kind());
14   if (alias != getOperatorAliasMap().end()) {
15     iter->replaceWithNewSymbol(alias->second);
16     iter.destroyCurrent();
17     return true;
18   }
19   return false;
20 }
21 
22 // Normalize rsub such that `rsub(x,y) = sub(x,y)`
normalizeRSub(graph_node_list_iterator & iter)23 bool normalizeRSub(graph_node_list_iterator& iter) {
24   if (iter->matches(
25           "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")) {
26     ArrayRef<Value*> args = iter->inputs();
27     Node* newSub = iter->replaceWithNewSymbol(aten::sub);
28     newSub->replaceInput(0, args[1]);
29     newSub->replaceInput(1, args[0]);
30     iter.destroyCurrent();
31     return true;
32   }
33   return false;
34 }
35 
36 // Normalizes a `__is__` comparison with a bool to `eq` (and same with
37 // `__isnot__`)
normalizeIsBool(graph_node_list_iterator & iter)38 bool normalizeIsBool(graph_node_list_iterator& iter) {
39   ArrayRef<Value*> args = iter->inputs();
40   if (args.size() == 2 && args[0]->type() == BoolType::get() &&
41       args[1]->type() == BoolType::get()) {
42     if (iter->kind() == aten::__is__) {
43       iter->replaceWithNewSymbol(aten::eq);
44       iter.destroyCurrent();
45       return true;
46     }
47     if (iter->kind() == aten::__isnot__) {
48       iter->replaceWithNewSymbol(aten::ne);
49       iter.destroyCurrent();
50       return true;
51     }
52   }
53   return false;
54 }
55 
NormalizeOps(Block * block)56 void NormalizeOps(Block* block) {
57   for (auto it = block->nodes().begin(), end = block->nodes().end();
58        it != end;) {
59     for (auto sub : it->blocks()) {
60       NormalizeOps(sub);
61     }
62 
63     if (normalizeRSub(it)) {
64       continue;
65     }
66 
67     if (normalizeOpAliases(it)) {
68       continue;
69     }
70 
71     if (normalizeIsBool(it)) {
72       continue;
73     }
74 
75     it++;
76   }
77 }
78 
79 } // namespace
80 
getOperatorAliasMap()81 const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
82   // map from op alias -> normalized op
83   static const std::unordered_map<Symbol, Symbol> alias_map = {
84       {aten::absolute, aten::abs},
85       {aten::absolute_, aten::abs_},
86       {aten::clip, aten::clamp},
87       {aten::clip_, aten::clamp_},
88       {aten::det, aten::linalg_det},
89       {aten::matrix_power, aten::linalg_matrix_power},
90       {aten::matrix_exp, aten::linalg_matrix_exp},
91       {aten::ger, aten::outer},
92       {aten::arccos, aten::acos},
93       {aten::arccos_, aten::acos_},
94       {aten::arcsin, aten::asin},
95       {aten::arcsin_, aten::asin_},
96       {aten::arctan, aten::atan},
97       {aten::arctan_, aten::atan_},
98       {aten::arctan2, aten::atan2},
99       {aten::arctan2_, aten::atan2_},
100       {aten::arccosh, aten::acosh},
101       {aten::arccosh_, aten::acosh_},
102       {aten::arcsinh, aten::asinh},
103       {aten::arcsinh_, aten::asinh_},
104       {aten::arctanh, aten::atanh},
105       {aten::arctanh_, aten::atanh_},
106       {aten::fix, aten::trunc},
107       {aten::fix_, aten::trunc_},
108       {aten::negative, aten::neg},
109       {aten::negative_, aten::neg_},
110       {aten::subtract, aten::sub},
111       {aten::subtract_, aten::sub_},
112       {aten::greater_equal, aten::ge},
113       {aten::greater_equal_, aten::ge_},
114       {aten::greater, aten::gt},
115       {aten::greater_, aten::gt_},
116       {aten::less_equal, aten::le},
117       {aten::less_equal_, aten::le_},
118       {aten::less, aten::lt},
119       {aten::less_, aten::lt_},
120       {aten::not_equal, aten::ne},
121       {aten::not_equal_, aten::ne_},
122       {aten::divide, aten::div},
123       {aten::divide_, aten::div_},
124       {aten::multiply, aten::mul},
125       {aten::multiply_, aten::mul_},
126       {aten::linalg_matmul, aten::matmul},
127       {aten::inverse, aten::linalg_inv},
128       {aten::true_divide, aten::div},
129       {aten::true_divide_, aten::div_},
130       {aten::concat, aten::cat},
131       {aten::concatenate, aten::cat},
132       {aten::row_stack, aten::vstack},
133       {aten::swapdims, aten::transpose},
134       {aten::swapdims_, aten::transpose_},
135       {aten::swapaxes, aten::transpose},
136       {aten::swapaxes_, aten::transpose_},
137       {aten::moveaxis, aten::movedim},
138       {aten::special_erf, aten::erf},
139       {aten::special_erfc, aten::erfc},
140       {aten::special_erfinv, aten::erfinv},
141       {aten::special_expit, aten::sigmoid},
142       {aten::special_exp2, aten::exp2},
143       {aten::special_expm1, aten::expm1},
144       {aten::special_logit, aten::logit},
145       {aten::special_logsumexp, aten::logsumexp},
146       {aten::special_round, aten::round},
147       {aten::special_log1p, aten::log1p},
148       {aten::special_sinc, aten::sinc},
149       {aten::special_digamma, aten::digamma},
150       {aten::special_psi, aten::digamma},
151       {aten::special_i0, aten::i0},
152       {aten::special_xlogy, aten::xlogy},
153       {aten::special_log_softmax, aten::log_softmax},
154       {aten::orgqr, aten::linalg_householder_product},
155       {aten::adjoint, aten::mH},
156       {aten::special_multigammaln, aten::mvlgamma},
157       {aten::special_polygamma, aten::polygamma},
158       {aten::special_softmax, aten::softmax},
159       {aten::special_gammainc, aten::igamma},
160       {aten::special_gammaincc, aten::igammac},
161       {aten::special_gammaln, aten::lgamma}};
162   return alias_map;
163 }
164 
NormalizeOps(const std::shared_ptr<Graph> & graph)165 void NormalizeOps(const std::shared_ptr<Graph>& graph) {
166   NormalizeOps(graph->block());
167 }
168 
169 } // namespace torch::jit
170