#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch::jit {

namespace {

/**
 * Check whether the arithmetic node is a binary op between two integers, and
 * return the constant int value if one of the operands is a constant.
 *
 * @pre node is integer arithmetic.
 * @post if one of the two operands is a constant, then the second operand is
 *       the constant one.
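 *
 * A rough illustration (hypothetical IR, placeholder value names):
 *   %y : int = aten::mul(%three, %x)   // %three = prim::Constant[value=3]()
 * has its inputs swapped so that the constant sits second,
 *   %y : int = aten::mul(%x, %three)
 * and the call returns 3.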
 */
std::optional<int64_t> checkArithNode(Node& node) {
  if (node.inputs().size() != 2 || node.input(0)->type() != IntType::get() ||
      node.input(1)->type() != IntType::get()) {
    return {};
  }

  if (node.kind() == aten::mul || node.kind() == aten::add) {
    if (auto i = constant_as<int64_t>(node.input(0))) {
      node.permuteInputs({1, 0});
      return i;
    }
  }

  return constant_as<int64_t>(node.input(1));
}

/**
 * Remove a mul/div/floordiv node if it multiplies or divides by 1.
 *
 * @pre node is either aten::mul, aten::floordiv or aten::div
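 *
 * A rough illustration (hypothetical IR, placeholder value names):
 *   %one : int = prim::Constant[value=1]()
 *   %y : int = aten::mul(%x, %one)
 * Every use of %y is redirected to %x; the now-dead aten::mul node is left
 * for a later dead-code-elimination pass to clean up.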
 */
bool trySimplifyMulOrDiv(Node& node) {
  auto constant = checkArithNode(node);
  if (!constant || *constant != 1) {
    return false;
  }

  node.output()->replaceAllUsesWith(node.inputs()[0]);
  return true;
}

/**
 * Simplify an add/sub node with its input node, i.e. merge the constant parts
 * together.
 *
 * @pre node is either aten::add or aten::sub
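 *
 * A sketch of the constant-merging case (hypothetical IR, placeholder names):
 *   %a : int = aten::add(%x, %two)     // %two = prim::Constant[value=2]()
 *   %b : int = aten::add(%a, %three)   // %three = prim::Constant[value=3]()
 * becomes %b = aten::add(%x, %five) with a freshly inserted constant 5; when
 * the merged constant is 0, %b is instead replaced with %x directly.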
 */
bool trySimplifyAddOrSub(Node& node) {
  auto constant = checkArithNode(node);
  if (!constant) {
    return false;
  }

  if (constant == 0) {
    node.output()->replaceAllUsesWith(node.input(0));
    return true;
  }

  auto& dep = *node.inputs()[0]->node();
  if (dep.kind() != aten::add && dep.kind() != aten::sub) {
    return false;
  }

  auto delta = checkArithNode(dep);
  if (!delta) {
    return false;
  }
  auto merged =
      dep.kind() == node.kind() ? *constant + *delta : *constant - *delta;

  if (merged == 0) {
    node.output()->replaceAllUsesWith(dep.inputs()[0]);
  } else {
    WithInsertPoint g(&node);
    node.replaceInput(0, dep.inputs()[0]);
    node.replaceInput(1, node.owningGraph()->insertConstant(merged));
  }
  return true;
}

} // namespace

struct PeepholeOptimizeNonTensorImpl {
  PeepholeOptimizeNonTensorImpl(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    return optimizeBlock(graph_->block());
  }

  bool optimizeBlock(Block* block) {
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto* node = *it;

      for (Block* sub_block : node->blocks()) {
        changed |= optimizeBlock(sub_block);
      }

      if (node->kind() != prim::Constant) {
        WithInsertPoint guard(node);
        // Any Value whose type is None should be replaced with a Constant.
        // This can occur if a module has an optional attribute that is
        // initialized as None.
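        // A hypothetical example: a module attribute declared as
        // Optional[int] and left as None shows up here as a Value of
        // NoneType; every use of it is redirected to a fresh prim::Constant
        // holding None, so later passes see a real constant.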
        for (Value* output : node->outputs()) {
          if (output->type()->cast<NoneType>()) {
            output->replaceAllUsesWith(graph_->insertConstant(IValue()));
            changed = true;
          }
        }
      }
      // XXX: remember that if you want to simplify an expression by combining
      // multiple nodes into a different one, then you need to check that they
      // all belong to the given block
      // TODO: this doesn't work with Scalar-Tensor ops! We should
      // canonicalize those
      if (node->kind() == prim::If) {
        IfView n(node);
        // this handles redundant short circuits like "x and True" or
        // "x or False"
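        // The pattern being matched looks roughly like this (illustrative
        // IR, not an actual graph dump):
        //   %out : bool = prim::If(%cond)
        //     block0(): -> (%true_const)
        //     block1(): -> (%false_const)
        // i.e. the then-branch yields the constant True and the else-branch
        // yields the constant False, so %out can simply become %cond.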
        for (const auto i : c10::irange(n.outputs().size())) {
          if (n.outputs().at(i)->type() != BoolType::get()) {
            continue;
          }
          bool true_val =
              constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
          bool false_val =
              constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
          // if an if node's output equals its condition, replace the output
          // with the condition
          if (true_val && !false_val) {
            GRAPH_UPDATE(
                "Replacing ",
                n.outputs().at(i)->debugName(),
                " (True or False) with ",
                n.cond()->debugName());
            n.outputs().at(i)->replaceAllUsesWith(n.cond());
            changed = true;
          }
        }

        // check for types that can be refined
        for (size_t i = 0; i < n.outputs().size(); ++i) {
          // common case of optional for now
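          // A hypothetical example: both branches produce an int but the If
          // output is still typed Optional[int]; unifying the two branch
          // types lets the output be retyped as plain int.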
          bool inputs_non_optional =
              !n.thenOutputs().at(i)->type()->cast<OptionalType>() &&
              !n.elseOutputs().at(i)->type()->cast<OptionalType>();
          auto output_optional =
              n.outputs().at(i)->type()->cast<OptionalType>();
          if (inputs_non_optional && output_optional) {
            if (auto unif = unifyTypes(
                    n.thenOutputs().at(i)->type(),
                    n.elseOutputs().at(i)->type())) {
              n.outputs().at(i)->setType(*unif);
              changed = true;
            }
          }
        }
      } else if (
          node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
        // if we are comparing a None value with a value that can't be None,
        // replace the output with true if the node is __isnot__, or false if
        // it is __is__
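        // For instance (hypothetical): if one input is the None constant and
        // the other is statically known to be a Tensor, "%a is %b" folds to
        // the constant False and "%a is not %b" folds to the constant True.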
        AT_ASSERT(node->inputs().size() == 2);
        for (size_t check_none_index : {0, 1}) {
          bool input_must_be_none =
              node->inputs().at(check_none_index)->mustBeNone();
          bool other_must_not_be_none =
              node->inputs().at(1 - check_none_index)->mustNotBeNone();
          if (input_must_be_none && other_must_not_be_none) {
            WithInsertPoint guard(node);
            auto output = node->owningGraph()->insertConstant(
                node->kind() == aten::__isnot__);
            GRAPH_UPDATE(
                "Folding ", getHeader(node), " to ", output->debugName());
            node->output()->replaceAllUsesWith(output);
            changed = true;
          }
        }
      } else if (
          node->kind() == prim::unchecked_unwrap_optional ||
          node->kind() == aten::_unwrap_optional) {
        // we are unwrapping an input that can't be None, remove the unwrap
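        // Hypothetical example: after an earlier "x is not None" refinement,
        // an aten::_unwrap_optional whose input is already known to be
        // non-optional is redundant, so its output can read the input
        // directly.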
        auto input = node->input();
        if (input->mustNotBeNone()) {
          GRAPH_UPDATE(
              "Unwrapping ",
              getHeader(node),
              " as ",
              node->input(),
              " can't be optional");
          node->output()->replaceAllUsesWith(node->input());
          changed = true;
        }
      } else if (node->kind() == prim::unchecked_cast) {
        // unchecked_cast is not generated for tensor properties, so we are not
        // losing anything by calling unshapedType here
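        // Hypothetical example: a prim::unchecked_cast from List[int] to
        // List[int], or more generally from a type to one of its supertypes,
        // does not change the value, so every use of the cast's output can
        // read the input instead.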
        auto input_type = unshapedType(node->input()->type());
        auto output_type = unshapedType(node->output()->type());
        if (input_type->isSubtypeOf(*output_type)) {
          GRAPH_UPDATE(
              "Removing ",
              getHeader(node),
              " as input type subtypes output type");
          node->output()->replaceAllUsesWith(node->input());
          changed = true;
        }
      } else if (
          (node->kind() == aten::Int || node->kind() == aten::ceil) &&
          node->inputs().size() == 1 &&
          node->input()->type()->cast<IntType>()) {
        GRAPH_UPDATE(
            "Removing ", getHeader(node), " as input is already an integer");
        node->output()->replaceAllUsesWith(node->input());
        changed = true;
      } else if (node->kind() == aten::ne || node->kind() == aten::eq) {
        if (node->inputs().size() != 2 ||
            node->inputs().at(0) != node->inputs().at(1)) {
          continue;
        }
        auto inp_type = node->inputs().at(0)->type();
        // only handling common immutable types here because other types like
        // Tensor or list of Tensor might throw on aten::eq
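        // Illustrative case: "%x == %x" where %x is an int (or a list/dict of
        // the immutable kinds below) always folds to True, and "%x != %x" to
        // False; the same comparison on Tensors is deliberately skipped
        // since, as noted above, aten::eq on Tensors might throw.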
        auto immut_type = [&](const TypePtr& type) {
          auto kind = type->kind();
          static const std::vector<TypeKind> handled_immutable_types = {
              TypeKind::BoolType,
              TypeKind::IntType,
              TypeKind::FloatType,
              TypeKind::NoneType};
          return (
              std::find(
                  handled_immutable_types.begin(),
                  handled_immutable_types.end(),
                  kind) != handled_immutable_types.end());
        };
        bool non_throwing_type = false;
        if (auto li_type = inp_type->cast<ListType>()) {
          non_throwing_type = immut_type(li_type->getElementType());
        } else if (auto di_type = inp_type->cast<DictType>()) {
          non_throwing_type =
              (immut_type(di_type->getKeyType()) &&
               immut_type(di_type->getValueType()));
        } else {
          non_throwing_type = immut_type(inp_type);
        }
        if (non_throwing_type) {
          WithInsertPoint guard(node);
          node->output()->replaceAllUsesWith(
              graph_->insertConstant(node->kind() == aten::eq));
          changed = true;
        }
      } else if (
          node->kind() == aten::mul || node->kind() == aten::floordiv ||
          node->kind() == aten::div) {
        changed |= trySimplifyMulOrDiv(*node);
      } else if (node->kind() == aten::add || node->kind() == aten::sub) {
        changed |= trySimplifyAddOrSub(*node);
      }
    }
    return changed;
  }

 private:
  std::shared_ptr<Graph> graph_;
};

bool PeepholeOptimizeNonTensor(const std::shared_ptr<Graph>& graph) {
  PeepholeOptimizeNonTensorImpl peephole(graph);
  bool changed = peephole.run();
  GRAPH_DUMP("After PeepholeOptimize: ", graph);
  return changed;
}
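
// A minimal usage sketch (illustrative; `g` is a hypothetical, already-built
// Graph, not something defined in this file):
//   std::shared_ptr<Graph> g = ...;
//   bool changed = PeepholeOptimizeNonTensor(g);
// `changed` reports whether any node was rewritten; callers that want a fixed
// point can simply rerun the pass until it returns false.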

} // namespace torch::jit