xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/dead_code_elimination.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // If given a top-level graph, DCE will construct do alias analysis that allows
8 // for "smarter" dead code elimination (we will eliminate mutable ops if we can
9 // prove the mutated values are not used). Otherwise, we will not allow DCE to
10 // eliminate mutable ops.
11 //
12 // So, prefer to use the graph version if you can.
13 enum class DCESideEffectPolicy : uint8_t {
14   // default behavior: dead code elimination will check if a node has side
15   // effects
16   // and not delete it if it does.
17   DONT_DELETE_NODES_WITH_SIDE_EFFECTS,
18   // with this flag, dead code elimination will not check if a node has side
19   // effects and treat nodes with side effects like any other node,
20   // i.e. delete them if their outputs aren't used anywhere.
21   ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS
22 };
23 
24 TORCH_API void EliminateDeadCode(
25     const std::shared_ptr<Graph>& graph,
26     DCESideEffectPolicy sideEffectPolicy =
27         DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
28 TORCH_API void EliminateDeadCode(
29     Block* block,
30     bool recurse = true,
31     DCESideEffectPolicy sideEffectPolicy =
32         DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
33 
34 // Invoke the user-provided callback on all live values before deleting anything
35 TORCH_API void EliminateDeadCode(
36     Block* block,
37     std::function<void(const std::unordered_set<const Value*>&)> cb,
38     DCESideEffectPolicy sideEffectPolicy =
39         DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
40 } // namespace torch::jit
41