xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/bailout_graph.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/bailout_graph.h>
2 
3 #include <ATen/core/function.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/clear_profiling.h>
9 #include <torch/csrc/jit/passes/constant_pooling.h>
10 #include <torch/csrc/jit/passes/liveness.h>
11 #include <memory>
12 #include <unordered_set>
13 #include <utility>
14 
15 namespace torch::jit {
16 
shouldBeCapturedInByBailOut(Node * n)17 static bool shouldBeCapturedInByBailOut(Node* n) {
18   return n->kind() != prim::Constant;
19 }
20 
21 struct BailOutGraphBuilderForNode {
BailOutGraphBuilderForNodetorch::jit::BailOutGraphBuilderForNode22   explicit BailOutGraphBuilderForNode(
23       std::shared_ptr<Graph> graph,
24       std::shared_ptr<Graph> target)
25       : graph_(std::move(graph)), copy_graph_(std::move(target)) {}
26 
27   // capture `old_value` into the bailout graph
28   // by creating a new input and mapping
29   // `old_value` to it
addNewInputForValuetorch::jit::BailOutGraphBuilderForNode30   Value* addNewInputForValue(Value* old_value) {
31     auto node = old_value->node();
32     // this reduces the number of inputs to a bailout graph significantly
33     // making it easier to debug
34     if (node->kind() == prim::Constant) {
35       TORCH_INTERNAL_ASSERT(!shouldBeCapturedInByBailOut(node));
36       auto new_const = copy_graph_->createClone(node, {nullptr});
37       copy_graph_->block()->prependNode(new_const);
38       return new_const->output();
39     }
40 
41     live_inputs_.push_back(old_value);
42     auto new_value = copy_graph_->block()->addInput();
43     GRAPH_DEBUG(
44         "Adding a new value %",
45         new_value->debugName(),
46         " for %",
47         old_value->debugName());
48     return mapValueAndCopyMetadata(old_value, new_value);
49   }
50 
mapValueAndCopyMetadatatorch::jit::BailOutGraphBuilderForNode51   Value* mapValueAndCopyMetadata(Value* old_value, Value* new_value) {
52     this->old_to_new_[old_value] = new_value;
53     new_value->copyMetadata(old_value);
54     return new_value;
55   }
56 
getOrAddInputForValuetorch::jit::BailOutGraphBuilderForNode57   Value* getOrAddInputForValue(Value* v) {
58     if (this->old_to_new_.count(v) == 0) {
59       return addNewInputForValue(v);
60     } else {
61       return this->old_to_new_[v];
62     }
63   }
64 
getInputForValuetorch::jit::BailOutGraphBuilderForNode65   Value* getInputForValue(Value* v) {
66     TORCH_INTERNAL_ASSERT(this->old_to_new_.count(v));
67     return this->old_to_new_[v];
68   }
69 
cloneNodetorch::jit::BailOutGraphBuilderForNode70   Node* cloneNode(Node* node) {
71     auto* block = copy_graph_->block();
72     auto env = [this](Value* v) { return getOrAddInputForValue(v); };
73 
74     auto new_node = block->appendNode(copy_graph_->createClone(node, env));
75     for (size_t i = 0; i < node->outputs().size(); ++i) {
76       auto oo = node->outputs()[i];
77       auto no = new_node->outputs()[i];
78       old_to_new_[oo] = no;
79     }
80 
81     return new_node;
82   }
83 
84   // buildBailOutBlockFrom builds a bailout graph from
85   // a given node `n` until the end of the owning block
86   // If `n` belongs to `prim::If` or `prim::Loop`
87   // buildBailOutLoop/If continue
88   // from block's owning node (e.g. `prim::If` or
89   // `prim::Loop`)
buildBailOutBlockFromtorch::jit::BailOutGraphBuilderForNode90   void buildBailOutBlockFrom(Node* n) {
91     auto b = n->owningBlock();
92     for (auto it = n->iterator(); it != b->nodes().end(); it++) {
93       cloneNode(*it);
94     }
95 
96     // we are either in `prim::If` or `prim::Loop`
97     // bailout graph building will continue from `outer_node` next
98     auto outer_node = n->owningBlock()->owningNode();
99     if (outer_node) {
100       if (outer_node->kind() == prim::Loop) {
101         buildBailOutLoop(outer_node);
102       } else if (outer_node->kind() == prim::If) {
103         buildBailOutIf(b->outputs(), outer_node);
104       } else {
105         AT_ERROR("Unexpected outer node");
106       }
107     }
108   }
109 
mapValuestorch::jit::BailOutGraphBuilderForNode110   void mapValues(
111       const at::ArrayRef<Value*> block_outputs,
112       const at::ArrayRef<Value*> carried_deps) {
113     TORCH_INTERNAL_ASSERT(block_outputs.size() == carried_deps.size());
114     for (const auto i : c10::irange(block_outputs.size())) {
115       auto nv = getOrAddInputForValue(block_outputs[i]);
116       old_to_new_[carried_deps[i]] = nv;
117     }
118   }
119 
buildBailOutLooptorch::jit::BailOutGraphBuilderForNode120   void buildBailOutLoop(Node* outer_node) {
121     LoopView lv(outer_node);
122     auto old_max_count = getOrAddInputForValue(lv.maxTripCount());
123     auto cur_iter = getInputForValue(lv.currentTripCount());
124     auto block_outputs = lv.bodyBlock()->outputs();
125 
126     auto* block = copy_graph_->block();
127     // subtract the number of iterations
128     WithInsertPoint guard(*block->nodes().end());
129     auto updated_max_trip_count =
130         copy_graph_->insert(aten::sub, {old_max_count, cur_iter});
131     auto one = copy_graph_->insertConstant({1});
132     updated_max_trip_count =
133         copy_graph_->insert(aten::sub, {updated_max_trip_count, one});
134     auto cur_plus_one = copy_graph_->insert(aten::add, {one, cur_iter});
135 
136     // We need to be careful when mapping `block_outputs` to continuation
137     // loop's inputs since `cloneFrom` will replace `%4` with the same value
138     // in both, `prim::Loop` and `aten::cat` in the example below:
139     //
140     // ... : Tensor = prim::Loop(%MAX_TRIP_COUNT, %COND, ..., %4)
141     //   block0(%i.2 : int, ...):
142     //     ...
143     //     %y.5 : Double(3) = aten::cat(%22, %4)
144     //     ...
145     //
146     // However for the cloned loop node, the values should be different.
147     // Namely, the value in `prim::Loop` should come from
148     // `lv.bodyBlock()->outputs()` which are mapped to the outputs of the
149     // current iteration whereas `%4` in `aten::cat` needs to be mapped to the
150     // cloned value of `%4` in a bailout graph. To work around this, we manually
151     // clone loop nodes
152 
153     // map the residual loop's inputs to the outputs of the current iteration
154     // (i.e. `block_outputs`)
155     auto new_loop =
156         copy_graph_->insertNode(copy_graph_->create(prim::Loop, {}, 0))
157             ->setSourceRange(outer_node->sourceRange());
158     new_loop->addInput(updated_max_trip_count);
159     for (auto bo : block_outputs) {
160       new_loop->addInput(getOrAddInputForValue(bo));
161     }
162 
163     // clone the loop body and map old loop's outputs to new loop's outputs
164     auto new_loop_body = new_loop->addBlock();
165     auto env = [this](Value* v) { return getOrAddInputForValue(v); };
166     new_loop_body->cloneFrom(lv.bodyBlock(), env);
167     for (auto ov : lv.carriedOutputs()) {
168       auto no = new_loop->addOutput();
169       mapValueAndCopyMetadata(ov, no);
170     }
171     LoopView new_lv(new_loop);
172     {
173       WithInsertPoint guard_in_loop(*new_lv.bodyBlock()->nodes().begin());
174       // `one` will be replaced with new_lv.currentTripCount()
175       // but it needs to be done after
176       // new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr);
177       // to avoid cyclical references
178       auto adj_iter_ctr = copy_graph_->insert(aten::add, {cur_plus_one, one});
179       new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr);
180       adj_iter_ctr->node()->replaceInputWith(one, new_lv.currentTripCount());
181     }
182 
183     if (outer_node->next()) {
184       buildBailOutBlockFrom(outer_node->next());
185     }
186   }
187 
buildBailOutIftorch::jit::BailOutGraphBuilderForNode188   void buildBailOutIf(
189       const at::ArrayRef<Value*> block_outputs,
190       Node* outer_node) {
191     auto if_outputs = outer_node->outputs();
192     mapValues(block_outputs, if_outputs);
193     buildBailOutBlockFrom(outer_node->next());
194   }
195 
buildBailOutGraphFromtorch::jit::BailOutGraphBuilderForNode196   std::shared_ptr<Graph> buildBailOutGraphFrom(Node* n) {
197     // add graph inputs for guard's input
198     // and loop counts for loops `n` is contained in
199     // to make sure we can line bailout grap's inputs up properly
200     // with arguments to this BailOut node.
201     for (auto bi : n->inputs()) {
202       getOrAddInputForValue(bi);
203     }
204 
205     buildBailOutBlockFrom(n);
206     // add graph outputs
207     for (auto ov : graph_->outputs()) {
208       copy_graph_->registerOutput(getOrAddInputForValue(ov));
209     }
210     return copy_graph_;
211   }
212 
213   std::shared_ptr<Graph> graph_;
214   std::shared_ptr<Graph> copy_graph_;
215   std::vector<Value*> live_inputs_;
216   std::unordered_map<Value*, Value*> old_to_new_;
217 };
218 
219 // `BailOutInserter` replaces prim::Guard nodes with
220 // prim::BailOut nodes that allow interpreter to
221 // resume execution of the unoptimized(deoptimized)
222 // version of an original graph from a particular point
223 struct BailOutInserter {
BailOutInsertertorch::jit::BailOutInserter224   explicit BailOutInserter(std::shared_ptr<Graph> graph)
225       : graph_(std::move(graph)) {}
226 
runtorch::jit::BailOutInserter227   void run() {
228     liveness_sets_ = BuildLivenessSets(graph_);
229     insertBailOuts(graph_->block());
230     replaceGuardsWithBailouts();
231     // embed a full original graph
232     addUnoptimizedFuncToBailouts();
233   }
234 
235   // Packs the original unoptimized graph into a Function constant
236   // and add it as the first input to every prim::BailOut point
237   // This graph will be used to compute a bailout graph for
238   // any given bailout point
addUnoptimizedFuncToBailoutstorch::jit::BailOutInserter239   void addUnoptimizedFuncToBailouts() {
240     auto unoptimized_graph = graph_->copy();
241     auto unopt_func = graph_->create(prim::BailoutTemplate)
242                           ->insertAfter(graph_->param_node());
243 
244     // Returns an int so that we have an easy way to do graph traversal
245     unopt_func->output()->setType(IntType::get());
246     unopt_func->g_(attr::Subgraph, std::move(unoptimized_graph));
247     for (auto bn : bailouts_) {
248       bn->insertInput(0, unopt_func->output());
249     }
250   }
251 
252   // Removes guards by hooking up the guarded tensor
253   // directly to its users and also clears
254   // profiling information on it.
removeGuardstorch::jit::BailOutInserter255   void removeGuards(Block* b) {
256     for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
257       if (it->kind() == prim::Guard) {
258         // this will need to be profiled again
259         it->input()->setType(TensorType::get());
260         // destroy the guard
261         it->output()->replaceAllUsesWith(it->input());
262         it.destroyCurrent();
263       }
264 
265       for (auto ib : it->blocks()) {
266         removeGuards(ib);
267       }
268     }
269   }
270 
271   // replace each prim::Guard
272   // with its corresponding prim::BailOut
replaceGuardsWithBailoutstorch::jit::BailOutInserter273   void replaceGuardsWithBailouts() {
274     for (auto e : replacements_) {
275       e.first->replaceAllUsesWith(e.second);
276       e.second->node()->insertAfter(e.first->node());
277       e.first->node()->destroy();
278     }
279   }
280 
281   // Inserts prim::BailOut nodes for every prim::Guard
282   // Each BailOut point takes the set of inputs live
283   // at that particular execution point.
284   // An input is live if it's used beyond the guard/BailOut
285   // point to compute graph's outputs
insertBailOutstorch::jit::BailOutInserter286   void insertBailOuts(Block* b) {
287     for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
288       if (it->kind() == prim::Guard) {
289         auto bailout_node = b->owningGraph()->create(prim::BailOut);
290         bailouts_.push_back(bailout_node);
291 
292         const auto& live_inputs = liveness_sets_[*it];
293 
294         // guarded inputs come first
295         // currently, there's always one guarded input
296         bailout_node->addInput(it->input());
297         for (auto li : live_inputs) {
298           // Guarded inputs have already been added
299           // Also, skip some inputs that BailOutGraphBuilder can
300           // materialize into bailout graphs directly
301           if (!shouldBeCapturedInByBailOut(li->node()) || li == it->input()) {
302             continue;
303           }
304           bailout_node->addInput(li);
305         }
306 
307         bailout_node->output()->setType(it->output()->type());
308         bailout_node->i_(attr::index, bailout_index_++);
309         // we can't immediately replace nodes since this action will corrupt
310         // the liveness sets of following BailOut nodes if any of their
311         // arguments are BailOut nodes themselves
312         replacements_.insert({it->output(), bailout_node->output()});
313 
314       } else {
315         for (auto ib : it->blocks()) {
316           insertBailOuts(ib);
317         }
318       }
319     }
320   }
321 
322   std::shared_ptr<Graph> graph_;
323   std::map<Node*, Node*> subgraphs;
324   std::size_t bailout_index_{0};
325   std::unordered_map<Node*, std::vector<Value*>> liveness_sets_;
326   std::vector<Node*> bailouts_;
327   std::map<Value*, Value*> replacements_;
328 };
329 
InsertBailOuts(std::shared_ptr<Graph> graph)330 void InsertBailOuts(std::shared_ptr<Graph> graph) {
331   BailOutInserter ibo(std::move(graph));
332   ibo.run();
333 }
334 
335 // linearly scans through graph's nodes to locate prim::BailOut whose
336 // index matches the given `index`
locateBailOutNodeInUnoptimizedGraph(Block * b,int64_t index)337 static Node* locateBailOutNodeInUnoptimizedGraph(Block* b, int64_t index) {
338   for (auto n : b->nodes()) {
339     if ((n->kind() == prim::BailOut || n->kind() == prim::Guard) &&
340         n->hasAttribute(attr::index) && n->i(attr::index) == index) {
341       return n;
342     }
343     for (auto ib : n->blocks()) {
344       if (auto bn = locateBailOutNodeInUnoptimizedGraph(ib, index)) {
345         return bn;
346       }
347     }
348   }
349   return nullptr;
350 }
351 
352 // Removes prim::BailOuts and hooks the guarded input directly
353 // to its users
removeBailouts(Block * b)354 static void removeBailouts(Block* b) {
355   for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
356     if (it->kind() == prim::BailOut || it->kind() == prim::Guard) {
357       // clear profiling information
358       it->inputs().at(0)->setType(TensorType::get());
359       it->output()->replaceAllUsesWith(it->inputs().at(0));
360       it.destroyCurrent();
361     } else {
362       for (auto ib : it->blocks()) {
363         removeBailouts(ib);
364       }
365     }
366   }
367 }
368 
369 // see `bailout_graph.h`
BuildBailOutGraphFrom(int64_t bailout_index,const std::shared_ptr<Graph> & orig,const std::shared_ptr<Graph> & target)370 TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
371     int64_t bailout_index,
372     const std::shared_ptr<Graph>& orig,
373     const std::shared_ptr<Graph>& target) {
374   auto orig_bailout_node =
375       locateBailOutNodeInUnoptimizedGraph(orig->block(), bailout_index);
376 
377   GRAPH_DEBUG("bailout triggered for ", *orig_bailout_node);
378   GRAPH_DUMP("original bailout graph ", orig);
379   TORCH_INTERNAL_ASSERT(
380       orig_bailout_node->inputs().at(0)->type()->cast<FunctionType>() ==
381       nullptr);
382   TORCH_INTERNAL_ASSERT(
383       orig_bailout_node &&
384       (orig_bailout_node->kind() == prim::BailOut ||
385        orig_bailout_node->kind() == prim::Guard) &&
386       bailout_index == orig_bailout_node->i(attr::index));
387   BailOutGraphBuilderForNode bg(orig, target);
388   auto bailout_graph = bg.buildBailOutGraphFrom(orig_bailout_node);
389 
390   removeBailouts(bailout_graph->block());
391   ClearProfilingInformation(bailout_graph);
392   GRAPH_DUMP("bailout_graph ", bailout_graph);
393   return bailout_graph;
394 }
395 
396 } // namespace torch::jit
397