xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/dead_code_elimination.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir_views.h>
6 #include <torch/csrc/jit/jit_log.h>
7 
8 #include <unordered_map>
9 
10 namespace torch::jit {
11 
12 namespace prim {
13 using namespace ::c10::prim;
14 }
15 
16 class DeadCodeEliminator {
17  public:
DeadCodeEliminator(std::shared_ptr<Graph> graph,DCESideEffectPolicy sideEffectPolicy)18   explicit DeadCodeEliminator(
19       std::shared_ptr<Graph> graph,
20       DCESideEffectPolicy sideEffectPolicy)
21       : sideEffectPolicy_(sideEffectPolicy),
22         graph_(std::move(graph)),
23         useAliasDb_(true) {}
DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)24   DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
25       : sideEffectPolicy_(sideEffectPolicy) {}
26 
27   // The algorithm is an inverse mark-and-sweep. Starting from the return node,
28   // we mark "live" nodes that are necessary for the output. Nodes that have
29   // side effects are also marked.
run(Block * block,bool recurse)30   void run(Block* block, bool recurse) {
31     // clean up unused fork inputs before starting the main algorithm
32     eliminateDeadForkInputs(block, recurse);
33 
34     // Initialize by marking the return node and all its consumed values as live
35     mark(block->return_node());
36 
37     mark(block);
38 
39     deleteCallback_(liveValues_);
40 
41     sweep(block, recurse);
42   }
43 
setDeleteCallback(std::function<void (const std::unordered_set<const Value * > &)> deleteCallback)44   void setDeleteCallback(
45       std::function<void(const std::unordered_set<const Value*>&)>
46           deleteCallback) {
47     deleteCallback_ = std::move(deleteCallback);
48   }
49 
50  private:
eliminateDeadForkInputs(Block * block,bool recurse)51   void eliminateDeadForkInputs(Block* block, bool recurse) {
52     for (Node* node : block->nodes()) {
53       if (recurse) {
54         for (Block* sb : node->blocks()) {
55           eliminateDeadForkInputs(sb, recurse);
56         }
57       }
58       if (node->kind() != prim::fork) {
59         continue;
60       }
61       Graph& g = *node->g(attr::Subgraph);
62       // WARNING: Do not use a ranged loop. The loop bounds are changed by the
63       // loop body.
64       for (size_t i = 0; i < g.inputs().size(); ++i) {
65         if (!g.inputs().at(i)->hasUses()) {
66           GRAPH_UPDATE(
67               "Dead ",
68               i,
69               "-th input ",
70               node->inputs().at(i)->debugName(),
71               "(",
72               g.inputs().at(i)->debugName(),
73               " in a subgraph) will be removed");
74           g.eraseInput(i);
75           node->removeInput(i);
76         }
77       }
78     }
79   }
80 
81   // Special handling for block return nodes. Unlike other nodes, the block
82   // return node doesn't really "use" its inputs. Consider:
83   //
84   // %a0 = aten::foo()
85   // %b = aten::foo()
86   // %a2, %b2 = prim::If(%cond) {
87   //   block0() {
88   //     %a1 = aten::foo(%.0)
89   //     %b1 = aten::foo(%b)
90   //   } -> (%a1, %b1)
91   // }
92   // return (%a2)
93   //
94   // We want to be able to DCE all the %b stuff. So when processing block
95   // returns, we only mark producers for values that "live" (i.e. used outside
96   // the block).
97   //
98   // Returns true iff this marked something we haven't marked before.
markReturnNode(Node * node)99   bool markReturnNode(Node* node) {
100     if (marked_.count(node)) {
101       return false;
102     }
103 
104     AT_ASSERT(node->owningBlock()->return_node() == node);
105     auto outerNode = node->owningBlock()->owningNode();
106     if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
107       // If there's no outer node, we're looking at the graph's top-level
108       // return block. We consider all graph outputs to be "used", so just mark
109       // this node normally.
110       return mark(node);
111     }
112 
113     // Collect all inputs that are actually live
114     if (outerNode->kind() == prim::Loop ||
115         outerNode->kind() == c10::onnx::Loop) {
116       // Special handling to deal with loop carried dependencies.
117       auto loop = LoopView(outerNode);
118       for (const auto i : c10::irange(loop.carriedOutputs().size())) {
119         if (outerNode->kind() == c10::onnx::Loop) {
120           // Special handling for onnx loop.
121           // The number of body carried inputs and outputs are different.
122           // They cannot be mapped to each other easily by the same index.
123           liveValues_.insert(loop.bodyCarriedOutputs().at(i));
124           continue;
125         }
126         auto innerInput = loop.bodyCarriedInputs().at(i);
127         auto innerOutput = loop.bodyCarriedOutputs().at(i);
128         auto outerOutput = loop.carriedOutputs().at(i);
129         if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
130           liveValues_.insert(innerOutput);
131         }
132       }
133 
134       // Also mark the loop next condition as live, since it will be used inside
135       // the loop body.
136       liveValues_.insert(loop.nextCond());
137     } else {
138       AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
139       for (const auto i : c10::irange(outerNode->outputs().size())) {
140         auto innerOutput = node->inputs()[i];
141         auto outerOutput = outerNode->outputs()[i];
142         if (liveValues_.count(outerOutput)) {
143           liveValues_.insert(innerOutput);
144         }
145       }
146     }
147 
148     marked_.insert(node);
149     return true;
150   }
151 
152   // Loops are special, because we need to run them to convergence.
153   // Consider the following loop:
154   //   for i in range(3):
155   //     tot += a[0][0]
156   //     b = a[0]
157   //     b[0] += 1
158   //   print(tot)
159   //
160   // If we only process the loop block once, we will conclude that `b[0]` and
161   // `b` are dead, even though `b[0] += 1` mutates a live memory location (since
162   // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
163   // iteration
164   //
165   // We need to mark the loop again with the information that `a` is live, and
166   // repeat until we're not marking new stuff anymore.
167   //
168   // Returns true iff this marked something we haven't marked before.
markLoop(Node * node)169   bool markLoop(Node* node) {
170     TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
171     // Did a single iteration over the loop block mark anything new?
172     // If this is false, we've converged.
173     bool marked = false;
174     // Did we ever mark anything new?
175     bool anyMarked = false;
176     do {
177       marked = mark(node->blocks().at(0));
178       anyMarked |= marked;
179     } while (marked);
180     return anyMarked;
181   }
182 
183   // Returns true iff this marked something we haven't marked before.
mark(Block * block)184   bool mark(Block* block) {
185     bool anyMarked = false;
186     // Mark all nodes with side effects.
187     for (auto node : block->nodes()) {
188       if (sideEffectPolicy_ ==
189               DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
190           hasSideEffects(node)) {
191         anyMarked |= mark(node);
192       }
193     }
194 
195     // Initialize by marking the return node
196     anyMarked |= markReturnNode(block->return_node());
197 
198     for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
199       auto node = *it;
200       if (node->kind() == prim::Loop) {
201         // Special casing for loops, see comment in markLoop.
202         anyMarked |= markLoop(node);
203       } else {
204         // Other nodes with sub-blocks get marked normally.
205         for (auto subBlock : node->blocks()) {
206           anyMarked |= mark(subBlock);
207         }
208       }
209       anyMarked |= markIfLive(node);
210     }
211     return anyMarked;
212   }
213 
214   // If we output or write to a live memory location, mark this node
215   // Returns true iff this marked something we haven't marked before.
markIfLive(Node * node)216   bool markIfLive(Node* node) {
217     for (const auto output : node->outputs()) {
218       if (liveValues_.count(output)) {
219         return mark(node);
220       }
221     }
222 
223     if (useAliasDb_) {
224       if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
225         return mark(node);
226       }
227     }
228 
229     return false;
230   }
231 
232   // Mark this node as live and add this node's inputs and aliases to the live
233   // value sets.
234   // Returns true iff this marked something we haven't marked before.
mark(Node * node)235   bool mark(Node* node) {
236     if (marked_.count(node)) {
237       return false;
238     }
239 
240     marked_.insert(node);
241 
242     // Mark all nodes in this node's blockchain (since owning nodes are
243     // considered live if they contain a live node)
244     auto curNode = node;
245     while (curNode) {
246       if (!curNode->owningBlock()) {
247         break;
248       }
249 
250       mark(curNode);
251       curNode = curNode->owningBlock()->owningNode();
252     }
253 
254     for (const auto input : node->inputs()) {
255       if (liveValues_.count(input)) {
256         continue;
257       }
258       liveValues_.insert(input);
259     }
260     return true;
261   }
262 
263   // Delete all unmarked nodes.
sweep(Block * block,bool recurse)264   void sweep(Block* block, bool recurse) {
265     auto nodes = block->nodes().reverse();
266     for (auto it = nodes.begin(); it != nodes.end(); it++) {
267       auto node = *it;
268       // note these occur before the recursion because we want to uncover
269       // dead code in the blocks used to calculate the output
270       removeDeadBlockOutputs(node);
271       removeDeadLoopOutputs(node);
272       if (recurse) {
273         for (Block* block : node->blocks()) {
274           sweep(block, true);
275         }
276       }
277       // NB: Checking hasUses() is required. AD graphs are not perfectly
278       // valid, as a node in grad_desc.f might be used in reverse_block.
279       // Reverse_block is inlined in grad_desc.f before it's separated
280       // to grad_desc.df.
281       if (!(marked_.count(node) || node->hasUses())) {
282         GRAPH_UPDATE(
283             "Node ",
284             it->kind().toQualString(),
285             " which outputs ",
286             (!node->outputs().empty() ? node->outputs().at(0)->debugName()
287                                       : "n/a"),
288             " will be removed");
289         it.destroyCurrent();
290       }
291     }
292   }
293 
hasUntrackedMutation(Node * node)294   bool hasUntrackedMutation(Node* node) {
295     if (!useAliasDb_) {
296       // If we don't have alias information, all mutable ops have unknown
297       // effects and can't be considered for elimination.
298 
299       if (node->kind() == prim::SetAttr) {
300         // SetAttr is a special case: it doesn't have a schema, but does
301         // have untracked mutations
302         return true;
303       }
304 
305       // onnx export calls EliminateDeadCode but sometimes passes invalid
306       // aten operators. So we call maybeSchema so we handle the cases when
307       // there is no valid schema for a node
308       auto schema = node->maybeSchema();
309       return schema && schema->is_mutable();
310     } else {
311       return getOrCreateAliasDb()->writesToWildcard(node);
312     }
313   }
314 
hasSideEffects(Node * node)315   bool hasSideEffects(Node* node) {
316     auto it = memo_.find(node);
317     if (it != memo_.end())
318       return it->second;
319     bool has_side_effects = node->hasSideEffects() ||
320         std::any_of(node->blocks().begin(),
321                     node->blocks().end(),
322                     [&](Block* b) {
323                       return std::any_of(
324                           b->nodes().begin(), b->nodes().end(), [&](Node* n) {
325                             return hasSideEffects(n);
326                           });
327                     }) ||
328         hasUntrackedMutation(node);
329 
330     memo_.emplace(node, has_side_effects);
331     return has_side_effects;
332   }
333 
removeDeadBlockOutputs(Node * node)334   void removeDeadBlockOutputs(Node* node) {
335     if (node->kind() != prim::If && node->kind() != prim::GradOf) {
336       return;
337     }
338 
339     for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
340       size_t i = i_1 - 1;
341       if (!node->outputs().at(i)->hasUses()) {
342         GRAPH_UPDATE(
343             "Dead ",
344             i,
345             "-th output ",
346             node->outputs().at(i)->debugName(),
347             " of node ",
348             node->kind().toQualString(),
349             " will be removed");
350         node->eraseOutput(i);
351         for (Block* b : node->blocks()) {
352           GRAPH_UPDATE(
353               "\tCorresponding block output ",
354               b->outputs().at(i)->debugName(),
355               " will be removed");
356           b->eraseOutput(i);
357         }
358       }
359     }
360   }
361 
removeDeadLoopOutputs(Node * node)362   void removeDeadLoopOutputs(Node* node) {
363     if (node->kind() != prim::Loop)
364       return;
365     auto loop_body = node->blocks().at(0);
366     auto loop_input_offset = 2; // offset of loop carried deps in input list
367     auto loop_body_offset =
368         1; // offset to the loop carried dependencies in block inputs/outputs
369 
370     for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
371       size_t i = i_1 - 1;
372       if (!node->outputs().at(i)->hasUses() &&
373           !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
374         logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset);
375         node->eraseOutput(i);
376         node->removeInput(loop_input_offset + i);
377         loop_body->eraseInput(loop_body_offset + i);
378         loop_body->eraseOutput(loop_body_offset + i);
379       }
380     }
381   }
382 
logDeadLoopOutputs(Node * node,size_t i,size_t loop_input_offset,size_t loop_body_offset)383   void logDeadLoopOutputs(
384       Node* node,
385       size_t i,
386       size_t loop_input_offset,
387       size_t loop_body_offset) {
388     auto loop_body = node->blocks().at(0);
389     GRAPH_UPDATE(
390         "Dead ",
391         loop_input_offset + i,
392         "-th input ",
393         node->inputs().at(i)->debugName(),
394         " will be removed");
395     GRAPH_UPDATE(
396         "Dead ",
397         i,
398         "-th output ",
399         node->outputs().at(i)->debugName(),
400         " will be removed");
401     GRAPH_UPDATE(
402         "\tDead block input ",
403         loop_body->inputs().at(loop_body_offset + i)->debugName(),
404         "at offset ",
405         loop_body_offset + i,
406         " will be removed");
407     GRAPH_UPDATE(
408         "\tDead block output ",
409         loop_body->outputs().at(loop_body_offset + i)->debugName(),
410         "at offset ",
411         loop_body_offset + i,
412         " will be removed");
413   }
414 
getOrCreateAliasDb()415   AliasDb* getOrCreateAliasDb() {
416     if (!aliasDb_) {
417       aliasDb_ = std::make_unique<AliasDb>(graph_);
418     }
419     return aliasDb_.get();
420   }
421 
422   DCESideEffectPolicy sideEffectPolicy_;
423 
424   std::shared_ptr<Graph> graph_;
425   bool useAliasDb_ = false;
426   // lazily initialized
427   std::unique_ptr<AliasDb> aliasDb_ = nullptr;
428   std::unordered_map<Node*, bool> memo_;
429   std::unordered_set<Node*> marked_;
430   std::unordered_set<const Value*> liveValues_;
431   std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
__anonb8bf968d0302(const std::unordered_set<const Value*>&) 432       [](const std::unordered_set<const Value*>&) {};
433 };
434 
EliminateDeadCode(const std::shared_ptr<Graph> & graph,DCESideEffectPolicy sideEffectPolicy)435 void EliminateDeadCode(
436     const std::shared_ptr<Graph>& graph,
437     DCESideEffectPolicy sideEffectPolicy) {
438   DeadCodeEliminator(graph, sideEffectPolicy)
439       .run(graph->block(), /*recurse=*/true);
440   GRAPH_DUMP("After EliminateDeadCode: ", graph);
441 }
442 
EliminateDeadCode(Block * block,bool recurse,DCESideEffectPolicy sideEffectPolicy)443 void EliminateDeadCode(
444     Block* block,
445     bool recurse,
446     DCESideEffectPolicy sideEffectPolicy) {
447   DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
448 }
449 
EliminateDeadCode(Block * block,std::function<void (const std::unordered_set<const Value * > &)> cb,DCESideEffectPolicy sideEffectPolicy)450 void EliminateDeadCode(
451     Block* block,
452     std::function<void(const std::unordered_set<const Value*>&)> cb,
453     DCESideEffectPolicy sideEffectPolicy) {
454   DeadCodeEliminator eliminator(sideEffectPolicy);
455   eliminator.setDeleteCallback(std::move(cb));
456   eliminator.run(block, /*recurse=*/true);
457 }
458 
459 } // namespace torch::jit
460