#include <torch/csrc/jit/passes/guard_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_set>

namespace torch::jit {

struct GuardElimination {
  GuardElimination(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}

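  // Runs the guard-elimination pipeline: move guards next to the defs of the
  // values they guard, coalesce duplicate guards on the same value, remove
  // guards dominated by an equivalent guard, then eliminate guards made
  // redundant by guarded inputs.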
  void run() {
    const size_t MAX_ATTEMPTS = 5;
    size_t attempts = MAX_ATTEMPTS;
    while (attempts-- && moveGuardsToDefs(graph_->block())) {
    }
    GRAPH_DUMP("After moveGuardsToDefs", graph_);
    coalesceGuards(graph_->block());
    GRAPH_DUMP("After coalesceGuards", graph_);
    removeDominatedGuards(graph_->block());
    GRAPH_DUMP("After removeDominatedGuards", graph_);
    eliminateRedundantGuards(graph_->block());
    GRAPH_DUMP("After eliminateRedundantGuards", graph_);
  }

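  // Lowered GradOf subgraphs appear as a prim::If whose condition is
  // produced by prim::AutogradAnyNonZero, e.g. (illustrative IR):
  //   %cond = prim::AutogradAnyNonZero(%grad_output)
  //   %res = prim::If(%cond) ...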
  static bool isLoweredGradOf(Node* n) {
    if (n->kind() != prim::If) {
      return false;
    }

    return n->input(0)->node()->kind() == prim::AutogradAnyNonZero;
  }

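  // Moves every prim::Guard as close as possible to the definition of the
  // value it guards (its guardee), so that later passes see contiguous runs
  // of guards right after the defs (or at the top of the guard's block when
  // the guardee lives in an enclosing block).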
  bool moveGuardsToDefs(Block* b) {
    bool changed = false;
    for (auto it = b->nodes().begin(); it != b->nodes().end();) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        // grab the next node before we move this one all the way back
        it++;
        auto guardee = n->inputs().at(0)->node();
        // alias analysis will try to hoist a node out of a loop
        // if asked. if guardee is in a loop, it should only
        // be moved to the beginning of the basic block
        // given the current implementation of AliasAnalysis
        if (guardee->owningBlock() != n->owningBlock()) {
          guardee = *n->owningBlock()->nodes().begin();
        }
        bool moved = aliasDb_->moveAfterTopologicallyValid(n, guardee);
        changed |= moved;
        if (moved) {
          GRAPH_UPDATE(
              "Moved ",
              n->output()->debugName(),
              " to ",
              n->inputs().at(0)->debugName());
        }
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          moveGuardsToDefs(ib);
        }
      }
    }

    if (b->owningNode() &&
        isLoweredGradOf(
            b->owningNode()) /*b->owningNode()->kind() == prim::If*/) {
      for (auto it = b->nodes().begin(); it != b->nodes().end();) {
        auto block_node = *it++;
        if (block_node->kind() != prim::Guard) {
          break;
        }
        block_node->moveBefore(b->owningNode());
        changed = true;
      }
    }

    return changed;
  }

  void coalesceGuards(Block* b) {
    // guards on *all* parameters are moved to the same anchor node and may
    // come in a different order after the anchor node,
    // e.g. (anchor, guard_x, guard_y, guard_x, guard_y)
    // this pass recognizes contiguous stretches of guards, keeping track of
    // the guards it has seen for each def; the next time it sees a guard on
    // the same def, it simply removes it.
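    // For example (illustrative IR):
    //   %g1 : Tensor = prim::Guard(%x)
    //   %g2 : Tensor = prim::Guard(%y)
    //   %g3 : Tensor = prim::Guard(%x)
    // uses of %g3 are redirected to %g1 and %g3 is destroyed.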
    std::unordered_map<Value*, Node*> inputs_to_guards;
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        if (inputs_to_guards.count(n->input())) {
          auto prev = inputs_to_guards[n->input()];
          n->output()->replaceAllUsesWith(prev->output());
          GRAPH_UPDATE(
              "Replacing ",
              n->output()->debugName(),
              " with ",
              prev->output()->debugName());
          it.destroyCurrent();
        } else {
          inputs_to_guards.insert({n->input(), n});
        }
      } else if (n->kind() != prim::Constant) {
        inputs_to_guards.clear();
        for (Block* ib : n->blocks()) {
          coalesceGuards(ib);
        }
      }
    }
  }

  void removeDominatedGuards(Block* b) {
    // If a Node guards a value which isn't mutated, then that node
    // can replace all other guards of the value which it dominates
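    // For example (illustrative IR, assuming %x has no writers):
    //   %g1 : Float(2, 3) = prim::Guard(%x)
    //   ...
    //   %g2 : Float(2, 3) = prim::Guard(%x)
    // %g2 is dominated by %g1 and has the same type, so it is rewired to
    // guard %g1's output and then removed.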
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        Value* input = n->input();
        if (aliasDb_->hasWriters(input)) {
          continue;
        }
        Value* guard_output = n->output();

        // find all uses of the input that the guard node dominates
        std::vector<Use> uses = input->uses();
        while (!uses.empty()) {
          auto use = uses.at(uses.size() - 1);
          uses.pop_back();

          // not all uses are guarded
          if (use.user->kind() != prim::Guard) {
            continue;
          }

          if (!use.user->isDominatedBy(n)) {
            continue;
          }

          // the dominated guard's type may differ from the dominator's
          // if it is only executed for a subtype, or if it is executed
          // in a different global context (e.g. with grad enabled), so
          // check that the types are equal before continuing

          auto dominator_type = guard_output->type();
          auto dominated_type = use.user->output()->type();

          if (*dominator_type == *dominated_type) {
            use.user->replaceInput(use.offset, guard_output);
          }
        }

        // remove redundant dominated guards
        std::vector<Use> users = n->output()->uses();
        for (auto use : users) {
          auto user = use.user;
          if (user->kind() == prim::Guard) {
            GRAPH_UPDATE(
                "Removing dominated guard ", user, " and replacing with ", n);
            user->output()->replaceAllUsesWith(guard_output);
            user->destroy();
          }
        }
      } else {
        for (Block* ib : n->blocks()) {
          removeDominatedGuards(ib);
        }
      }
    }
  }

  // we need to make sure there are no ops in between the guardee's
  // output and its guard other than other guards (and constants), since
  // intervening ops can invalidate shape information.
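  // For example (illustrative IR):
  //   %y = aten::mm(%a, %b)
  //   %c = prim::Constant[value=1]()   // allowed in between
  //   %g = prim::Guard(%y)             // guardsOutput(%g) returns true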
  bool guardsOutput(Node* guard) {
    auto output = guard->input()->node();
    auto it = guard;
    while (it != output) {
      if (it->kind() != prim::Guard && it->kind() != prim::Constant) {
        GRAPH_DEBUG(
            "found an unexpected node ",
            *it,
            " while trying to eliminate ",
            *guard);
        return false;
      }
      it = it->prev();
    }

    return true;
  }

  void eliminateRedundantGuards(Block* b) {
    // a very simple pass to eliminate redundant guards for ops
    // whose outputs are fully determined by their inputs,
    // i.e. if the inputs to such an op are guarded, we are allowed
    // to remove the guard on the op's output
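    // For example (illustrative IR), with %ga and %gb produced by guards:
    //   %c : Tensor = aten::mul(%ga, %gb)
    //   %gc : Float(2, 3) = prim::Guard(%c)
    // the guard %gc can be removed, and %c takes over its profiled type.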
    for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
      auto n = *it;
      if (n->kind() == prim::Guard && guardsOutput(n) &&
          removableGuard(n->inputs().at(0)->node())) {
        auto pttp = n->output()->type();
        n->output()->replaceAllUsesWith(n->inputs().at(0));
        n->inputs().at(0)->setType(pttp);
        GRAPH_UPDATE(
            "Eliminating the redundant guard ", n->output()->debugName());
        it.destroyCurrent();
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          eliminateRedundantGuards(ib);
        }
      }
    }
  }

  // `checkInputs` checks the invariants specified in `removableGuard`
  // on the inputs to `n`. The invariants must hold, or an input must
  // be a `prim::Constant` or be included as an exception in `except`.
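  // e.g. checkInputs(n, /*except=*/{2, 3}, /*allow_numbers=*/false) skips
  // the check on n's third and fourth inputs (as done for aten::multinomial
  // below).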
  bool checkInputs(
      Node* n,
      const std::unordered_set<size_t>& except,
      bool allow_numbers) {
    bool all_inputs_guarded = true;
    size_t i = 0;
    for (auto input : n->inputs()) {
      if ((input->node()->kind() == prim::Guard &&
           !input->type()->expectRef<TensorType>().isSummarized()) ||
          input->node()->kind() == prim::Constant ||
          (allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
          except.count(i) != 0) {
        AT_ASSERT(
            input->node()->kind() != prim::Guard ||
            input->type()->expect<TensorType>());
      } else {
        GRAPH_DEBUG(
            "input ",
            input->debugName(),
            " isn't guarded, type ",
            *input->type());
        all_inputs_guarded = false;
        break;
      }
      i++;
    }
    return all_inputs_guarded;
  }

 private:
  // `removableGuard` relies on the properties checked by `isSummarized()`,
  // and passes shouldn't insert nodes between a guard and its uses that
  // may alter those properties.
  // `removableGuard` expects type information to come directly from the
  // Profiler; passes shouldn't try to alter type information provided by
  // profiling.
  // While we can derive very simple rules stating when it's valid to remove
  // a `prim::Guard` on an operation's output if all of its inputs are
  // guarded, for some categories of operations, there is no comprehensive
  // set of rules that covers all the operations available in PyTorch.
  // If your operation falls into one of the categories described below,
  // add it to the switch statement below that contains the other operations
  // in that category. Otherwise, you will need to derive the rules for your
  // case on your own.
  // Generally, any operation that is stateful in any way, or that uses its
  // underlying data to compute any of the properties `isSummarized()`
  // checks, isn't amenable to guard elimination.
  // Categories:
  // * Functional-like operations (e.g. add, sub, le) with broadcast
  //   semantics: guards can be removed if all inputs are guarded and
  //   `isSummarized()` returns false, or inputs are `prim::Constant`
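  // For example (illustrative IR):
  //   %ga : Float(2, 3) = prim::Guard(%a)
  //   %gb : Float(2, 3) = prim::Guard(%b)
  //   %c : Tensor = aten::add(%ga, %gb, %one)
  //   %gc : Float(2, 3) = prim::Guard(%c)
  // here %gc is removable: both tensor inputs are guarded with
  // non-summarized types and the scalar %one is allowed as a number.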
  bool removableGuard(Node* n) {
    const static auto no_exceptions = std::unordered_set<size_t>{};
    switch (n->kind()) {
      case aten::add:
      case aten::add_:
      case aten::sub:
      case aten::mul:
      case aten::div:
      case aten::t:
      case aten::sigmoid:
      case aten::sin:
      case aten::cos:
      case aten::tan:
      case aten::sinh:
      case aten::cosh:
      case aten::tanh:
      case aten::asin:
      case aten::acos:
      case aten::atan:
      case aten::atan2:
      case aten::floor:
      case aten::fmod:
      case aten::ceil:
      case aten::trunc:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::remainder:
      case aten::mm:
      case aten::min:
      case aten::max:
      case aten::type_as:
      case aten::ge:
      case aten::gt:
      case aten::lt:
      case aten::le:
      case aten::eq:
      case aten::ne:
      case aten::neg:
      case prim::ConstantChunk:
      case aten::size:
      case aten::abs:
      case aten::sign:
      case aten::pow:
      case aten::relu:
      case aten::threshold:
      case prim::AutogradAdd:
      case prim::AutogradZero:
      case aten::rand_like:
      case aten::erf:
      case aten::erfc:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log2:
      case aten::log10:
      case aten::frac:
      case aten::lerp:
      case aten::lgamma:
      case aten::reciprocal:
      case aten::addcmul:
      case aten::where:
      case aten::_cast_Float:
      case aten::_cast_Long:
      case aten::__and__:
      case aten::__or__:
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__:
      case aten::bitwise_not:
      case aten::bitwise_and:
      case aten::bitwise_or:
      case aten::bitwise_xor:
        return checkInputs(n, no_exceptions, true);
      case aten::softmax:
        return checkInputs(n, std::unordered_set<size_t>{1}, true);
      case aten::multinomial:
        return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
      case aten::flatten:
      case aten::argmax:
      case aten::squeeze:
      case aten::avg_pool2d:
        return checkInputs(n, no_exceptions, false);
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
        return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
      case aten::slice:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the dimension argument is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // the start offset is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // the end offset is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // the stride is constant
            n->input(4)->node()->kind() == prim::Constant;
      case aten::max_pool1d:
      case aten::max_pool2d:
      case aten::max_pool3d:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the kernel size is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // check that the stride is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // check that the padding is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // check that the dilation is constant
            n->input(4)->node()->kind() == prim::Constant &&
            // check that the ceil_mode is constant
            n->input(5)->node()->kind() == prim::Constant;
      case aten::unsqueeze:
        // check that the dimension argument is constant
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            n->input(1)->node()->kind() == prim::Constant;
      case aten::cat:
        // check that the dimension argument is constant
        return n->input(1)->node()->kind() == prim::Constant &&
            n->input(0)->node()->kind() == prim::ListConstruct &&
            // no extra nodes in between aten::cat and prim::ListConstruct
            n->prev() == n->input(0)->node() &&
            // check the inputs to prim::ListConstruct (not aten::cat)
            checkInputs(n->input(0)->node(), no_exceptions, false);
      case aten::clamp:
        // the second and third args do not affect shapes
        return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
      // after some optimizations we might end up with two Guards back-to-back,
      // in which case we can remove the one whose input is also a prim::Guard
      case aten::_grad_sum_to_size:
        // skip checking size argument
        if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
          auto asize = n->input(1)->node();
          if (asize->kind() == prim::Constant) {
            return true;
          } else if (asize->matches("aten::size(Tensor self) -> int[]")) {
            // aten::size is effectively a constant
            if (asize->input()
                    ->type()
                    ->expectRef<TensorType>()
                    .sizes()
                    .concrete_sizes()) {
              return true;
            }
          }
        }
        return false;

      // this is checked by one of the tests in test_jit_fuser.py
      case prim::ListUnpack: {
        // check if the input is a constant chunk
        // used for LSTM fusions
        auto chunk = n->input(0)->node();
        if (chunk->kind() != aten::chunk) {
          return false;
        }
        return checkInputs(chunk, no_exceptions, false);
      }
      // this is checked by one of the tests in test_jit_fuser.py
      case aten::broadcast_tensors: {
        auto list_construct = n->input(0)->node();
        if (list_construct->kind() != prim::ListConstruct) {
          return false;
        }
        return checkInputs(list_construct, no_exceptions, false);
      }
      case prim::Guard:
      case prim::GradOf:
        return true;
      default:
        GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
        return false;
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  static std::unordered_set<Symbol> simple_ops_;
};

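// Entry point: runs the full guard-elimination pipeline on `graph`.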
void EliminateRedundantGuards(std::shared_ptr<Graph> graph) {
  GuardElimination ge(std::move(graph));
  ge.run();
}

} // namespace torch::jit