xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/liveness.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/liveness.h>
2 
3 #include <torch/csrc/jit/ir/alias_analysis.h>
4 #include <torch/csrc/jit/ir/ir_views.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <iostream>
7 #include <memory>
8 
9 namespace torch::jit {
10 
11 // LivenessAnalyzer computes "bailout" liveness which is equivalent to
12 // "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
13 struct LivenessAnalyzer {
LivenessAnalyzertorch::jit::LivenessAnalyzer14   explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
15       : graph_(std::move(graph)) {}
16 
runtorch::jit::LivenessAnalyzer17   std::unordered_map<Node*, std::vector<Value*>> run() {
18     std::vector<Node*> counters;
19     insertExplicitUsesOfLoopCounters(graph_->block(), counters);
20 
21     // we implement the canonical fixed-point liveness
22     // the analysis is run until there are no more changes
23     // to liveness sets for each node
24     do {
25       changed_ = false;
26       processBlock(graph_->block(), SparseBitVector{});
27     } while (changed_);
28 
29     removeCounterNodes(counters);
30     std::unordered_map<Node*, std::vector<Value*>> result;
31 
32     for (const auto& e : liveness_sets_) {
33       result.insert({e.first, toValueVector(e.second)});
34     }
35     return result;
36   }
37 
38   // temporary make loop counts live for the duration of the loop
39   // as they are needed by BailOuts in the loop
insertExplicitUsesOfLoopCounterstorch::jit::LivenessAnalyzer40   void insertExplicitUsesOfLoopCounters(
41       Block* b,
42       std::vector<Node*>& counters) {
43     for (auto it : b->nodes()) {
44       if (it->kind() == prim::Loop) {
45         LoopView lv(it);
46         WithInsertPoint guard(lv.bodyBlock());
47         auto ctc = graph_->create(prim::Store, {lv.currentTripCount()}, 0);
48         graph_->insertNode(ctc);
49         counters.push_back(ctc);
50         auto mtc = graph_->create(prim::Store, {lv.maxTripCount()}, 0);
51         graph_->insertNode(mtc);
52         counters.push_back(mtc);
53       }
54 
55       for (auto ib : it->blocks()) {
56         insertExplicitUsesOfLoopCounters(ib, counters);
57       }
58     }
59   }
60 
removeCounterNodestorch::jit::LivenessAnalyzer61   void removeCounterNodes(std::vector<Node*>& counters) {
62     for (auto n : counters) {
63       n->destroy();
64     }
65   }
66 
dumptorch::jit::LivenessAnalyzer67   void dump(
68       const std::unordered_map<Node*, std::vector<Value*>>& liveness_sets) {
69     std::cout << "Liveness info:\n";
70     for (auto e : liveness_sets) {
71       if (!e.first->outputs().empty()) {
72         std::cout << e.first->outputs()[0]->debugName();
73       }
74 
75       std::cout << " " << e.first->kind().toQualString();
76       std::cout << " = ";
77       dump(e.second);
78       std::cout << '\n';
79     }
80     std::cout << "graph :\n";
81     graph_->dump();
82   }
83 
dumptorch::jit::LivenessAnalyzer84   void dump(const std::vector<Value*>& set) {
85     bool first = true;
86     std::cout << "[";
87     for (auto el : set) {
88       if (first) {
89         first = false;
90       } else {
91         std::cout << ", ";
92       }
93       std::cout << el->debugName() << "(" << el->unique() << ")";
94     }
95     std::cout << "]";
96   }
97 
98  private:
toSparseBitVectortorch::jit::LivenessAnalyzer99   SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
100     SparseBitVector sbv;
101     for (auto v : values) {
102       ids_to_values_[v->unique()] = v;
103       sbv.set(v->unique());
104     }
105     return sbv;
106   }
107 
toValueVectortorch::jit::LivenessAnalyzer108   std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
109     std::vector<Value*> vec;
110     for (auto id : sbv) {
111       vec.push_back(ids_to_values_[id]);
112     }
113     return vec;
114   }
115 
processBlocktorch::jit::LivenessAnalyzer116   SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
117     // block outputs are the uses
118     auto block_outputs = toSparseBitVector(b->outputs());
119     liveness |= block_outputs;
120 
121     SparseBitVector defs;
122     for (Node* it : b->nodes().reverse()) {
123       // kill outputs
124       liveness -= toSparseBitVector(it->outputs());
125       if (it->kind() == prim::Loop) {
126         LoopView lv(it);
127         // N.B. merge in changes from the loop header
128         auto loop_header = *lv.bodyBlock()->nodes().begin();
129         auto loop_block = liveness | liveness_sets_[loop_header];
130         loop_block = processBlock(lv.bodyBlock(), loop_block);
131         // loop block's inputs die outside loop's block
132         loop_block -= toSparseBitVector(lv.bodyBlock()->inputs());
133         liveness |= loop_block;
134       } else if (it->kind() == prim::If) {
135         IfView iv(it);
136         auto true_liveness = processBlock(iv.thenBlock(), liveness);
137         auto false_liveness = processBlock(iv.elseBlock(), liveness);
138         liveness |= true_liveness;
139         liveness |= false_liveness;
140       }
141       liveness |= toSparseBitVector(it->inputs());
142       // `|=` returns true if new bits were set in LHS
143       // after or/union with `liveness`
144       auto changed = liveness_sets_[it] |= liveness;
145       changed_ = changed_ | changed;
146     }
147     return liveness;
148   }
149 
150   std::shared_ptr<Graph> graph_;
151   bool changed_{false};
152   std::map<Node*, SparseBitVector> liveness_sets_;
153   std::map<size_t, Value*> ids_to_values_;
154 };
155 
BuildLivenessSets(std::shared_ptr<Graph> graph)156 std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
157     std::shared_ptr<Graph> graph) {
158   LivenessAnalyzer la(std::move(graph));
159   return la.run();
160 }
161 
162 } // namespace torch::jit
163