#include <torch/csrc/jit/passes/guard_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_set>

namespace torch::jit {

struct GuardElimination {
  GuardElimination(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}

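  // Runs the guard-elimination pipeline: first move guards next to the
  // definitions they guard, then coalesce duplicate guards, remove guards
  // dominated by an equivalent guard, and finally drop guards whose outputs
  // are fully determined by already-guarded inputs.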
  void run() {
    const size_t MAX_ATTEMPTS = 5;
    size_t attempts = MAX_ATTEMPTS;
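    // repeatedly move guards toward the definitions they guard until no more
    // guards move, capped at MAX_ATTEMPTS iterations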
    while (attempts-- && moveGuardsToDefs(graph_->block())) {
    }
    GRAPH_DUMP("After moveGuardsToDefs", graph_);
    coalesceGuards(graph_->block());
    GRAPH_DUMP("After coalesceGuards", graph_);
    removeDominatedGuards(graph_->block());
    GRAPH_DUMP("After removeDominatedGuards", graph_);
    eliminateRedundantGuards(graph_->block());
    GRAPH_DUMP("After eliminateRedundantGuards", graph_);
  }

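  // Returns true if `n` is the prim::If produced by lowering a GradOf block,
  // i.e. an If whose condition comes from a prim::AutogradAnyNonZero node.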
  static bool isLoweredGradOf(Node* n) {
    if (n->kind() != prim::If) {
      return false;
    }

    return n->input(0)->node()->kind() == prim::AutogradAnyNonZero;
  }

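  // Attempts to move every prim::Guard directly after the definition of the
  // value it guards (or to the top of its block when the definition lives in
  // an outer block). Returns true if any guard was moved.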
  bool moveGuardsToDefs(Block* b) {
    bool changed = false;
    for (auto it = b->nodes().begin(); it != b->nodes().end();) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        // grab the next node before we move this one all the way back
        it++;
        auto guardee = n->inputs().at(0)->node();
        // alias analysis will try to hoist a node out of a loop
        // if asked. if guardee is in a loop, it should only
        // be moved to the beginning of the basic block
        // given the current implementation of AliasAnalysis
        if (guardee->owningBlock() != n->owningBlock()) {
          guardee = *n->owningBlock()->nodes().begin();
        }
        bool moved = aliasDb_->moveAfterTopologicallyValid(n, guardee);
        changed |= moved;
        if (moved) {
          GRAPH_UPDATE(
              "Moved ",
              n->output()->debugName(),
              " to ",
              n->inputs().at(0)->debugName());
        }
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          moveGuardsToDefs(ib);
        }
      }
    }

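    // in a block produced by lowering GradOf, hoist the leading run of
    // guards out in front of the owning prim::If node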
    if (b->owningNode() &&
        isLoweredGradOf(
            b->owningNode()) /*b->owningNode()->kind() == prim::If*/) {
      for (auto it = b->nodes().begin(); it != b->nodes().end();) {
        auto block_node = *it++;
        if (block_node->kind() != prim::Guard) {
          break;
        }
        block_node->moveBefore(b->owningNode());
        changed = true;
      }
    }

    return changed;
  }

  void coalesceGuards(Block* b) {
    // uses of *all* parameters are moved to the same anchor node,
    // and they may come in a different order after the anchor node,
    // e.g. (anchor, guard_x, guard_y, guard_x, guard_y).
    // this pass recognizes contiguous stretches of guards and
    // keeps track of the guards it has seen for each def; the next time
    // it sees a guard on the same def, it simply removes it.
    std::unordered_map<Value*, Node*> inputs_to_guards;
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        if (inputs_to_guards.count(n->input())) {
          auto prev = inputs_to_guards[n->input()];
          n->output()->replaceAllUsesWith(prev->output());
          GRAPH_UPDATE(
              "Replacing ",
              n->output()->debugName(),
              " with ",
              prev->output()->debugName());
          it.destroyCurrent();
        } else {
          inputs_to_guards.insert({n->input(), n});
        }
      } else if (n->kind() != prim::Constant) {
        inputs_to_guards.clear();
        for (Block* ib : n->blocks()) {
          coalesceGuards(ib);
        }
      }
    }
  }

  void removeDominatedGuards(Block* b) {
    // If a node guards a value which isn't mutated, then that node
    // can replace all other guards of the value which it dominates
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        Value* input = n->input();
        if (aliasDb_->hasWriters(input)) {
          continue;
        }
        Value* guard_output = n->output();

        // find all uses of the input that the guard node dominates
        std::vector<Use> uses = input->uses();
        while (!uses.empty()) {
          auto use = uses.at(uses.size() - 1);
          uses.pop_back();

          // not all uses are guarded
          if (use.user->kind() != prim::Guard) {
            continue;
          }

          if (!use.user->isDominatedBy(n)) {
            continue;
          }

          // the dominated guard's type may differ from the dominator's
          // if it is only executed for a subtype, or if it is executed
          // in a different global context (e.g. whether grad mode is enabled);
          // check that the types are equal before replacing

          auto dominator_type = guard_output->type();
          auto dominated_type = use.user->output()->type();

          if (*dominator_type == *dominated_type) {
            use.user->replaceInput(use.offset, guard_output);
          }
        }

        // remove redundant dominated guards
        std::vector<Use> users = n->output()->uses();
        for (auto use : users) {
          auto user = use.user;
          if (user->kind() == prim::Guard) {
            GRAPH_UPDATE(
                "Removing dominated guard ", user, " and replacing with ", n);
            user->output()->replaceAllUsesWith(guard_output);
            user->destroy();
          }
        }
      } else {
        for (Block* ib : n->blocks()) {
          removeDominatedGuards(ib);
        }
      }
    }
  }

  // we need to make sure there are no ops in between the guardee's
  // output and its guard except for other guards, as those ops can
  // invalidate shape information.
  bool guardsOutput(Node* guard) {
    auto output = guard->input()->node();
    auto it = guard;
    while (it != output) {
      if (it->kind() != prim::Guard && it->kind() != prim::Constant) {
        GRAPH_DEBUG(
            "found an unexpected node ",
            *it,
            " while trying to eliminate ",
            *guard);
        return false;
      }
      it = it->prev();
    }

    return true;
  }

  void eliminateRedundantGuards(Block* b) {
    // a very simple pass to eliminate redundant guards for ops
    // whose outputs are fully determined by their inputs,
    // i.e. if the inputs to such ops are guarded, we are allowed
    // to remove a guard on the ops' outputs
    for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
      auto n = *it;
      if (n->kind() == prim::Guard && guardsOutput(n) &&
          removableGuard(n->inputs().at(0)->node())) {
        auto pttp = n->output()->type();
        n->output()->replaceAllUsesWith(n->inputs().at(0));
        n->inputs().at(0)->setType(pttp);
        GRAPH_UPDATE(
            "Eliminating the redundant guard ", n->output()->debugName());
        it.destroyCurrent();
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          eliminateRedundantGuards(ib);
        }
      }
    }
  }

  // `checkInputs` checks the invariants specified in `removableGuard`
  // on the inputs to `n`. The invariants must hold, or an input must
  // be a `prim::Constant` or be included as an exception in `except`
  bool checkInputs(
      Node* n,
      const std::unordered_set<size_t>& except,
      bool allow_numbers) {
    bool all_inputs_guarded = true;
    size_t i = 0;
    for (auto input : n->inputs()) {
      if ((input->node()->kind() == prim::Guard &&
           !input->type()->expectRef<TensorType>().isSummarized()) ||
          input->node()->kind() == prim::Constant ||
          (allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
          except.count(i) != 0) {
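        // any input that does come from a guard must be a tensor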
        AT_ASSERT(
            input->node()->kind() != prim::Guard ||
            input->type()->expect<TensorType>());
      } else {
        GRAPH_DEBUG(
            "input ",
            input->debugName(),
            " isn't guarded, type ",
            *input->type());
        all_inputs_guarded = false;
        break;
      }
      i++;
    }
    return all_inputs_guarded;
  }

 private:
  // `removableGuard` relies on the properties checked by `isSummarized()`,
  // and passes shouldn't insert nodes between a guard and its uses that
  // may alter those properties.
  // `removableGuard` expects type information to come directly from the
  // profiler. Passes shouldn't try to alter type information provided by
  // profiling.
  // While for some categories of operations we can derive very simple rules
  // stating when it's valid to remove a `prim::Guard` on an operation's
  // output if all of its inputs are guarded, there's no comprehensive set
  // of rules that covers all the operations available in PyTorch.
  // If your operation falls into one of the categories described below, you
  // should add it to the switch statement below, next to the other
  // operations in the said category.
  // Otherwise, you will need to derive the rules for your case on your own.
  // Generally, any operation that is stateful in any way, or that uses its
  // underlying data to compute any of the properties `isSummarized()` checks,
  // isn't amenable to guard elimination.
  // Categories:
  //   * Functional-like operations (e.g. add, sub, le) with broadcast
  //     semantics: guards can be removed if all inputs are guarded and
  //     `isSummarized()` returns false, or inputs are `prim::Constant`
  bool removableGuard(Node* n) {
    const static auto no_exceptions = std::unordered_set<size_t>{};
    switch (n->kind()) {
      case aten::add:
      case aten::add_:
      case aten::sub:
      case aten::mul:
      case aten::div:
      case aten::t:
      case aten::sigmoid:
      case aten::sin:
      case aten::cos:
      case aten::tan:
      case aten::sinh:
      case aten::cosh:
      case aten::tanh:
      case aten::asin:
      case aten::acos:
      case aten::atan:
      case aten::atan2:
      case aten::floor:
      case aten::fmod:
      case aten::ceil:
      case aten::trunc:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::remainder:
      case aten::mm:
      case aten::min:
      case aten::max:
      case aten::type_as:
      case aten::ge:
      case aten::gt:
      case aten::lt:
      case aten::le:
      case aten::eq:
      case aten::ne:
      case aten::neg:
      case prim::ConstantChunk:
      case aten::size:
      case aten::abs:
      case aten::sign:
      case aten::pow:
      case aten::relu:
      case aten::threshold:
      case prim::AutogradAdd:
      case prim::AutogradZero:
      case aten::rand_like:
      case aten::erf:
      case aten::erfc:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log2:
      case aten::log10:
      case aten::frac:
      case aten::lerp:
      case aten::lgamma:
      case aten::reciprocal:
      case aten::addcmul:
      case aten::where:
      case aten::_cast_Float:
      case aten::_cast_Long:
      case aten::__and__:
      case aten::__or__:
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__:
      case aten::bitwise_not:
      case aten::bitwise_and:
      case aten::bitwise_or:
      case aten::bitwise_xor:
        return checkInputs(n, no_exceptions, true);
      case aten::softmax:
        return checkInputs(n, std::unordered_set<size_t>{1}, true);
      case aten::multinomial:
        return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
      case aten::flatten:
      case aten::argmax:
      case aten::squeeze:
      case aten::avg_pool2d:
        return checkInputs(n, no_exceptions, false);
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
        return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
      case aten::slice:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the dimension argument is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // the start offset is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // the end offset is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // the stride is constant
            n->input(4)->node()->kind() == prim::Constant;
      case aten::max_pool1d:
      case aten::max_pool2d:
      case aten::max_pool3d:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the kernel size is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // check that the stride is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // check that the padding is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // check that the dilation is constant
            n->input(4)->node()->kind() == prim::Constant &&
            // check that the ceil_mode is constant
            n->input(5)->node()->kind() == prim::Constant;
      case aten::unsqueeze:
        // check that the dimension argument is constant
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            n->input(1)->node()->kind() == prim::Constant;
      case aten::cat:
        // check that the dimension argument is constant
        return n->input(1)->node()->kind() == prim::Constant &&
            n->input(0)->node()->kind() == prim::ListConstruct &&
            // no extra nodes in between aten::cat and prim::ListConstruct
            n->prev() == n->input(0)->node() &&
            // check the inputs to prim::ListConstruct (not aten::cat)
            checkInputs(n->input(0)->node(), no_exceptions, false);
      case aten::clamp:
        // the second and third args do not affect shapes
        return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
      // after some optimizations we might end up with two Guards back-to-back,
      // in which case we can remove the one whose input is also a prim::Guard
      case aten::_grad_sum_to_size:
        // skip checking the size argument
        if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
          auto asize = n->input(1)->node();
          if (asize->kind() == prim::Constant) {
            return true;
          } else if (asize->matches("aten::size(Tensor self) -> int[]")) {
            // aten::size is effectively a constant
            if (asize->input()
                    ->type()
                    ->expectRef<TensorType>()
                    .sizes()
                    .concrete_sizes()) {
              return true;
            }
          }
        }
        return false;

      // this is checked by one of the tests in test_jit_fuser.py
      case prim::ListUnpack: {
        // check if the input is a constant chunk
        // used for LSTM fusions
        auto chunk = n->input(0)->node();
        if (chunk->kind() != aten::chunk) {
          return false;
        }
        return checkInputs(chunk, no_exceptions, false);
      }
      // this is checked by one of the tests in test_jit_fuser.py
      case aten::broadcast_tensors: {
        auto list_construct = n->input(0)->node();
        if (list_construct->kind() != prim::ListConstruct) {
          return false;
        }
        return checkInputs(list_construct, no_exceptions, false);
      }
      case prim::Guard:
      case prim::GradOf:
        return true;
      default:
        GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
        return false;
    }
  }

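  // the graph being optimized and the alias analysis built over it;
  // `simple_ops_` is declared but not referenced anywhere in this file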
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  static std::unordered_set<Symbol> simple_ops_;
};

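// Entry point: builds a GuardElimination pass over `graph` and runs the full
// pipeline in place (typically invoked by the profiling graph executor after
// guards have been inserted).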
void EliminateRedundantGuards(std::shared_ptr<Graph> graph) {
  GuardElimination ge(std::move(graph));
  ge.run();
}

} // namespace torch::jit