1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir_views.h>
6 #include <torch/csrc/jit/jit_log.h>
7
8 #include <unordered_map>
9
10 namespace torch::jit {
11
12 namespace prim {
13 using namespace ::c10::prim;
14 }
15
16 class DeadCodeEliminator {
17 public:
DeadCodeEliminator(std::shared_ptr<Graph> graph,DCESideEffectPolicy sideEffectPolicy)18 explicit DeadCodeEliminator(
19 std::shared_ptr<Graph> graph,
20 DCESideEffectPolicy sideEffectPolicy)
21 : sideEffectPolicy_(sideEffectPolicy),
22 graph_(std::move(graph)),
23 useAliasDb_(true) {}
DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)24 DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
25 : sideEffectPolicy_(sideEffectPolicy) {}
26
27 // The algorithm is an inverse mark-and-sweep. Starting from the return node,
28 // we mark "live" nodes that are necessary for the output. Nodes that have
29 // side effects are also marked.
run(Block * block,bool recurse)30 void run(Block* block, bool recurse) {
31 // clean up unused fork inputs before starting the main algorithm
32 eliminateDeadForkInputs(block, recurse);
33
34 // Initialize by marking the return node and all its consumed values as live
35 mark(block->return_node());
36
37 mark(block);
38
39 deleteCallback_(liveValues_);
40
41 sweep(block, recurse);
42 }
43
setDeleteCallback(std::function<void (const std::unordered_set<const Value * > &)> deleteCallback)44 void setDeleteCallback(
45 std::function<void(const std::unordered_set<const Value*>&)>
46 deleteCallback) {
47 deleteCallback_ = std::move(deleteCallback);
48 }
49
50 private:
eliminateDeadForkInputs(Block * block,bool recurse)51 void eliminateDeadForkInputs(Block* block, bool recurse) {
52 for (Node* node : block->nodes()) {
53 if (recurse) {
54 for (Block* sb : node->blocks()) {
55 eliminateDeadForkInputs(sb, recurse);
56 }
57 }
58 if (node->kind() != prim::fork) {
59 continue;
60 }
61 Graph& g = *node->g(attr::Subgraph);
62 // WARNING: Do not use a ranged loop. The loop bounds are changed by the
63 // loop body.
64 for (size_t i = 0; i < g.inputs().size(); ++i) {
65 if (!g.inputs().at(i)->hasUses()) {
66 GRAPH_UPDATE(
67 "Dead ",
68 i,
69 "-th input ",
70 node->inputs().at(i)->debugName(),
71 "(",
72 g.inputs().at(i)->debugName(),
73 " in a subgraph) will be removed");
74 g.eraseInput(i);
75 node->removeInput(i);
76 }
77 }
78 }
79 }
80
81 // Special handling for block return nodes. Unlike other nodes, the block
82 // return node doesn't really "use" its inputs. Consider:
83 //
84 // %a0 = aten::foo()
85 // %b = aten::foo()
86 // %a2, %b2 = prim::If(%cond) {
87 // block0() {
88 // %a1 = aten::foo(%.0)
89 // %b1 = aten::foo(%b)
90 // } -> (%a1, %b1)
91 // }
92 // return (%a2)
93 //
94 // We want to be able to DCE all the %b stuff. So when processing block
95 // returns, we only mark producers for values that "live" (i.e. used outside
96 // the block).
97 //
98 // Returns true iff this marked something we haven't marked before.
markReturnNode(Node * node)99 bool markReturnNode(Node* node) {
100 if (marked_.count(node)) {
101 return false;
102 }
103
104 AT_ASSERT(node->owningBlock()->return_node() == node);
105 auto outerNode = node->owningBlock()->owningNode();
106 if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
107 // If there's no outer node, we're looking at the graph's top-level
108 // return block. We consider all graph outputs to be "used", so just mark
109 // this node normally.
110 return mark(node);
111 }
112
113 // Collect all inputs that are actually live
114 if (outerNode->kind() == prim::Loop ||
115 outerNode->kind() == c10::onnx::Loop) {
116 // Special handling to deal with loop carried dependencies.
117 auto loop = LoopView(outerNode);
118 for (const auto i : c10::irange(loop.carriedOutputs().size())) {
119 if (outerNode->kind() == c10::onnx::Loop) {
120 // Special handling for onnx loop.
121 // The number of body carried inputs and outputs are different.
122 // They cannot be mapped to each other easily by the same index.
123 liveValues_.insert(loop.bodyCarriedOutputs().at(i));
124 continue;
125 }
126 auto innerInput = loop.bodyCarriedInputs().at(i);
127 auto innerOutput = loop.bodyCarriedOutputs().at(i);
128 auto outerOutput = loop.carriedOutputs().at(i);
129 if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
130 liveValues_.insert(innerOutput);
131 }
132 }
133
134 // Also mark the loop next condition as live, since it will be used inside
135 // the loop body.
136 liveValues_.insert(loop.nextCond());
137 } else {
138 AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
139 for (const auto i : c10::irange(outerNode->outputs().size())) {
140 auto innerOutput = node->inputs()[i];
141 auto outerOutput = outerNode->outputs()[i];
142 if (liveValues_.count(outerOutput)) {
143 liveValues_.insert(innerOutput);
144 }
145 }
146 }
147
148 marked_.insert(node);
149 return true;
150 }
151
152 // Loops are special, because we need to run them to convergence.
153 // Consider the following loop:
154 // for i in range(3):
155 // tot += a[0][0]
156 // b = a[0]
157 // b[0] += 1
158 // print(tot)
159 //
160 // If we only process the loop block once, we will conclude that `b[0]` and
161 // `b` are dead, even though `b[0] += 1` mutates a live memory location (since
162 // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
163 // iteration
164 //
165 // We need to mark the loop again with the information that `a` is live, and
166 // repeat until we're not marking new stuff anymore.
167 //
168 // Returns true iff this marked something we haven't marked before.
markLoop(Node * node)169 bool markLoop(Node* node) {
170 TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
171 // Did a single iteration over the loop block mark anything new?
172 // If this is false, we've converged.
173 bool marked = false;
174 // Did we ever mark anything new?
175 bool anyMarked = false;
176 do {
177 marked = mark(node->blocks().at(0));
178 anyMarked |= marked;
179 } while (marked);
180 return anyMarked;
181 }
182
183 // Returns true iff this marked something we haven't marked before.
mark(Block * block)184 bool mark(Block* block) {
185 bool anyMarked = false;
186 // Mark all nodes with side effects.
187 for (auto node : block->nodes()) {
188 if (sideEffectPolicy_ ==
189 DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
190 hasSideEffects(node)) {
191 anyMarked |= mark(node);
192 }
193 }
194
195 // Initialize by marking the return node
196 anyMarked |= markReturnNode(block->return_node());
197
198 for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
199 auto node = *it;
200 if (node->kind() == prim::Loop) {
201 // Special casing for loops, see comment in markLoop.
202 anyMarked |= markLoop(node);
203 } else {
204 // Other nodes with sub-blocks get marked normally.
205 for (auto subBlock : node->blocks()) {
206 anyMarked |= mark(subBlock);
207 }
208 }
209 anyMarked |= markIfLive(node);
210 }
211 return anyMarked;
212 }
213
214 // If we output or write to a live memory location, mark this node
215 // Returns true iff this marked something we haven't marked before.
markIfLive(Node * node)216 bool markIfLive(Node* node) {
217 for (const auto output : node->outputs()) {
218 if (liveValues_.count(output)) {
219 return mark(node);
220 }
221 }
222
223 if (useAliasDb_) {
224 if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
225 return mark(node);
226 }
227 }
228
229 return false;
230 }
231
232 // Mark this node as live and add this node's inputs and aliases to the live
233 // value sets.
234 // Returns true iff this marked something we haven't marked before.
mark(Node * node)235 bool mark(Node* node) {
236 if (marked_.count(node)) {
237 return false;
238 }
239
240 marked_.insert(node);
241
242 // Mark all nodes in this node's blockchain (since owning nodes are
243 // considered live if they contain a live node)
244 auto curNode = node;
245 while (curNode) {
246 if (!curNode->owningBlock()) {
247 break;
248 }
249
250 mark(curNode);
251 curNode = curNode->owningBlock()->owningNode();
252 }
253
254 for (const auto input : node->inputs()) {
255 if (liveValues_.count(input)) {
256 continue;
257 }
258 liveValues_.insert(input);
259 }
260 return true;
261 }
262
263 // Delete all unmarked nodes.
sweep(Block * block,bool recurse)264 void sweep(Block* block, bool recurse) {
265 auto nodes = block->nodes().reverse();
266 for (auto it = nodes.begin(); it != nodes.end(); it++) {
267 auto node = *it;
268 // note these occur before the recursion because we want to uncover
269 // dead code in the blocks used to calculate the output
270 removeDeadBlockOutputs(node);
271 removeDeadLoopOutputs(node);
272 if (recurse) {
273 for (Block* block : node->blocks()) {
274 sweep(block, true);
275 }
276 }
277 // NB: Checking hasUses() is required. AD graphs are not perfectly
278 // valid, as a node in grad_desc.f might be used in reverse_block.
279 // Reverse_block is inlined in grad_desc.f before it's separated
280 // to grad_desc.df.
281 if (!(marked_.count(node) || node->hasUses())) {
282 GRAPH_UPDATE(
283 "Node ",
284 it->kind().toQualString(),
285 " which outputs ",
286 (!node->outputs().empty() ? node->outputs().at(0)->debugName()
287 : "n/a"),
288 " will be removed");
289 it.destroyCurrent();
290 }
291 }
292 }
293
hasUntrackedMutation(Node * node)294 bool hasUntrackedMutation(Node* node) {
295 if (!useAliasDb_) {
296 // If we don't have alias information, all mutable ops have unknown
297 // effects and can't be considered for elimination.
298
299 if (node->kind() == prim::SetAttr) {
300 // SetAttr is a special case: it doesn't have a schema, but does
301 // have untracked mutations
302 return true;
303 }
304
305 // onnx export calls EliminateDeadCode but sometimes passes invalid
306 // aten operators. So we call maybeSchema so we handle the cases when
307 // there is no valid schema for a node
308 auto schema = node->maybeSchema();
309 return schema && schema->is_mutable();
310 } else {
311 return getOrCreateAliasDb()->writesToWildcard(node);
312 }
313 }
314
hasSideEffects(Node * node)315 bool hasSideEffects(Node* node) {
316 auto it = memo_.find(node);
317 if (it != memo_.end())
318 return it->second;
319 bool has_side_effects = node->hasSideEffects() ||
320 std::any_of(node->blocks().begin(),
321 node->blocks().end(),
322 [&](Block* b) {
323 return std::any_of(
324 b->nodes().begin(), b->nodes().end(), [&](Node* n) {
325 return hasSideEffects(n);
326 });
327 }) ||
328 hasUntrackedMutation(node);
329
330 memo_.emplace(node, has_side_effects);
331 return has_side_effects;
332 }
333
removeDeadBlockOutputs(Node * node)334 void removeDeadBlockOutputs(Node* node) {
335 if (node->kind() != prim::If && node->kind() != prim::GradOf) {
336 return;
337 }
338
339 for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
340 size_t i = i_1 - 1;
341 if (!node->outputs().at(i)->hasUses()) {
342 GRAPH_UPDATE(
343 "Dead ",
344 i,
345 "-th output ",
346 node->outputs().at(i)->debugName(),
347 " of node ",
348 node->kind().toQualString(),
349 " will be removed");
350 node->eraseOutput(i);
351 for (Block* b : node->blocks()) {
352 GRAPH_UPDATE(
353 "\tCorresponding block output ",
354 b->outputs().at(i)->debugName(),
355 " will be removed");
356 b->eraseOutput(i);
357 }
358 }
359 }
360 }
361
removeDeadLoopOutputs(Node * node)362 void removeDeadLoopOutputs(Node* node) {
363 if (node->kind() != prim::Loop)
364 return;
365 auto loop_body = node->blocks().at(0);
366 auto loop_input_offset = 2; // offset of loop carried deps in input list
367 auto loop_body_offset =
368 1; // offset to the loop carried dependencies in block inputs/outputs
369
370 for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
371 size_t i = i_1 - 1;
372 if (!node->outputs().at(i)->hasUses() &&
373 !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
374 logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset);
375 node->eraseOutput(i);
376 node->removeInput(loop_input_offset + i);
377 loop_body->eraseInput(loop_body_offset + i);
378 loop_body->eraseOutput(loop_body_offset + i);
379 }
380 }
381 }
382
logDeadLoopOutputs(Node * node,size_t i,size_t loop_input_offset,size_t loop_body_offset)383 void logDeadLoopOutputs(
384 Node* node,
385 size_t i,
386 size_t loop_input_offset,
387 size_t loop_body_offset) {
388 auto loop_body = node->blocks().at(0);
389 GRAPH_UPDATE(
390 "Dead ",
391 loop_input_offset + i,
392 "-th input ",
393 node->inputs().at(i)->debugName(),
394 " will be removed");
395 GRAPH_UPDATE(
396 "Dead ",
397 i,
398 "-th output ",
399 node->outputs().at(i)->debugName(),
400 " will be removed");
401 GRAPH_UPDATE(
402 "\tDead block input ",
403 loop_body->inputs().at(loop_body_offset + i)->debugName(),
404 "at offset ",
405 loop_body_offset + i,
406 " will be removed");
407 GRAPH_UPDATE(
408 "\tDead block output ",
409 loop_body->outputs().at(loop_body_offset + i)->debugName(),
410 "at offset ",
411 loop_body_offset + i,
412 " will be removed");
413 }
414
getOrCreateAliasDb()415 AliasDb* getOrCreateAliasDb() {
416 if (!aliasDb_) {
417 aliasDb_ = std::make_unique<AliasDb>(graph_);
418 }
419 return aliasDb_.get();
420 }
421
422 DCESideEffectPolicy sideEffectPolicy_;
423
424 std::shared_ptr<Graph> graph_;
425 bool useAliasDb_ = false;
426 // lazily initialized
427 std::unique_ptr<AliasDb> aliasDb_ = nullptr;
428 std::unordered_map<Node*, bool> memo_;
429 std::unordered_set<Node*> marked_;
430 std::unordered_set<const Value*> liveValues_;
431 std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
__anonb8bf968d0302(const std::unordered_set<const Value*>&) 432 [](const std::unordered_set<const Value*>&) {};
433 };
434
EliminateDeadCode(const std::shared_ptr<Graph> & graph,DCESideEffectPolicy sideEffectPolicy)435 void EliminateDeadCode(
436 const std::shared_ptr<Graph>& graph,
437 DCESideEffectPolicy sideEffectPolicy) {
438 DeadCodeEliminator(graph, sideEffectPolicy)
439 .run(graph->block(), /*recurse=*/true);
440 GRAPH_DUMP("After EliminateDeadCode: ", graph);
441 }
442
EliminateDeadCode(Block * block,bool recurse,DCESideEffectPolicy sideEffectPolicy)443 void EliminateDeadCode(
444 Block* block,
445 bool recurse,
446 DCESideEffectPolicy sideEffectPolicy) {
447 DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
448 }
449
EliminateDeadCode(Block * block,std::function<void (const std::unordered_set<const Value * > &)> cb,DCESideEffectPolicy sideEffectPolicy)450 void EliminateDeadCode(
451 Block* block,
452 std::function<void(const std::unordered_set<const Value*>&)> cb,
453 DCESideEffectPolicy sideEffectPolicy) {
454 DeadCodeEliminator eliminator(sideEffectPolicy);
455 eliminator.setDeleteCallback(std::move(cb));
456 eliminator.run(block, /*recurse=*/true);
457 }
458
459 } // namespace torch::jit
460