1 #include <torch/csrc/jit/passes/normalize_ops.h>
2
3 #include <c10/util/Exception.h>
4
5 namespace torch::jit {
6
7 namespace {
8
9 // having multiple ops in our IR that do the same thing makes the IR more
10 // difficult to consumer for downstream user of the IR, such as our own
11 // optimization passes here, we convert op aliases into a standard form
normalizeOpAliases(graph_node_list_iterator & iter)12 bool normalizeOpAliases(graph_node_list_iterator& iter) {
13 auto alias = getOperatorAliasMap().find(iter->kind());
14 if (alias != getOperatorAliasMap().end()) {
15 iter->replaceWithNewSymbol(alias->second);
16 iter.destroyCurrent();
17 return true;
18 }
19 return false;
20 }
21
22 // Normalize rsub such that `rsub(x,y) = sub(x,y)`
normalizeRSub(graph_node_list_iterator & iter)23 bool normalizeRSub(graph_node_list_iterator& iter) {
24 if (iter->matches(
25 "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")) {
26 ArrayRef<Value*> args = iter->inputs();
27 Node* newSub = iter->replaceWithNewSymbol(aten::sub);
28 newSub->replaceInput(0, args[1]);
29 newSub->replaceInput(1, args[0]);
30 iter.destroyCurrent();
31 return true;
32 }
33 return false;
34 }
35
36 // Normalizes a `__is__` comparison with a bool to `eq` (and same with
37 // `__isnot__`)
normalizeIsBool(graph_node_list_iterator & iter)38 bool normalizeIsBool(graph_node_list_iterator& iter) {
39 ArrayRef<Value*> args = iter->inputs();
40 if (args.size() == 2 && args[0]->type() == BoolType::get() &&
41 args[1]->type() == BoolType::get()) {
42 if (iter->kind() == aten::__is__) {
43 iter->replaceWithNewSymbol(aten::eq);
44 iter.destroyCurrent();
45 return true;
46 }
47 if (iter->kind() == aten::__isnot__) {
48 iter->replaceWithNewSymbol(aten::ne);
49 iter.destroyCurrent();
50 return true;
51 }
52 }
53 return false;
54 }
55
NormalizeOps(Block * block)56 void NormalizeOps(Block* block) {
57 for (auto it = block->nodes().begin(), end = block->nodes().end();
58 it != end;) {
59 for (auto sub : it->blocks()) {
60 NormalizeOps(sub);
61 }
62
63 if (normalizeRSub(it)) {
64 continue;
65 }
66
67 if (normalizeOpAliases(it)) {
68 continue;
69 }
70
71 if (normalizeIsBool(it)) {
72 continue;
73 }
74
75 it++;
76 }
77 }
78
79 } // namespace
80
getOperatorAliasMap()81 const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
82 // map from op alias -> normalized op
83 static const std::unordered_map<Symbol, Symbol> alias_map = {
84 {aten::absolute, aten::abs},
85 {aten::absolute_, aten::abs_},
86 {aten::clip, aten::clamp},
87 {aten::clip_, aten::clamp_},
88 {aten::det, aten::linalg_det},
89 {aten::matrix_power, aten::linalg_matrix_power},
90 {aten::matrix_exp, aten::linalg_matrix_exp},
91 {aten::ger, aten::outer},
92 {aten::arccos, aten::acos},
93 {aten::arccos_, aten::acos_},
94 {aten::arcsin, aten::asin},
95 {aten::arcsin_, aten::asin_},
96 {aten::arctan, aten::atan},
97 {aten::arctan_, aten::atan_},
98 {aten::arctan2, aten::atan2},
99 {aten::arctan2_, aten::atan2_},
100 {aten::arccosh, aten::acosh},
101 {aten::arccosh_, aten::acosh_},
102 {aten::arcsinh, aten::asinh},
103 {aten::arcsinh_, aten::asinh_},
104 {aten::arctanh, aten::atanh},
105 {aten::arctanh_, aten::atanh_},
106 {aten::fix, aten::trunc},
107 {aten::fix_, aten::trunc_},
108 {aten::negative, aten::neg},
109 {aten::negative_, aten::neg_},
110 {aten::subtract, aten::sub},
111 {aten::subtract_, aten::sub_},
112 {aten::greater_equal, aten::ge},
113 {aten::greater_equal_, aten::ge_},
114 {aten::greater, aten::gt},
115 {aten::greater_, aten::gt_},
116 {aten::less_equal, aten::le},
117 {aten::less_equal_, aten::le_},
118 {aten::less, aten::lt},
119 {aten::less_, aten::lt_},
120 {aten::not_equal, aten::ne},
121 {aten::not_equal_, aten::ne_},
122 {aten::divide, aten::div},
123 {aten::divide_, aten::div_},
124 {aten::multiply, aten::mul},
125 {aten::multiply_, aten::mul_},
126 {aten::linalg_matmul, aten::matmul},
127 {aten::inverse, aten::linalg_inv},
128 {aten::true_divide, aten::div},
129 {aten::true_divide_, aten::div_},
130 {aten::concat, aten::cat},
131 {aten::concatenate, aten::cat},
132 {aten::row_stack, aten::vstack},
133 {aten::swapdims, aten::transpose},
134 {aten::swapdims_, aten::transpose_},
135 {aten::swapaxes, aten::transpose},
136 {aten::swapaxes_, aten::transpose_},
137 {aten::moveaxis, aten::movedim},
138 {aten::special_erf, aten::erf},
139 {aten::special_erfc, aten::erfc},
140 {aten::special_erfinv, aten::erfinv},
141 {aten::special_expit, aten::sigmoid},
142 {aten::special_exp2, aten::exp2},
143 {aten::special_expm1, aten::expm1},
144 {aten::special_logit, aten::logit},
145 {aten::special_logsumexp, aten::logsumexp},
146 {aten::special_round, aten::round},
147 {aten::special_log1p, aten::log1p},
148 {aten::special_sinc, aten::sinc},
149 {aten::special_digamma, aten::digamma},
150 {aten::special_psi, aten::digamma},
151 {aten::special_i0, aten::i0},
152 {aten::special_xlogy, aten::xlogy},
153 {aten::special_log_softmax, aten::log_softmax},
154 {aten::orgqr, aten::linalg_householder_product},
155 {aten::adjoint, aten::mH},
156 {aten::special_multigammaln, aten::mvlgamma},
157 {aten::special_polygamma, aten::polygamma},
158 {aten::special_softmax, aten::softmax},
159 {aten::special_gammainc, aten::igamma},
160 {aten::special_gammaincc, aten::igammac},
161 {aten::special_gammaln, aten::lgamma}};
162 return alias_map;
163 }
164
NormalizeOps(const std::shared_ptr<Graph> & graph)165 void NormalizeOps(const std::shared_ptr<Graph>& graph) {
166 NormalizeOps(graph->block());
167 }
168
169 } // namespace torch::jit
170