xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/bailout_graph.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/jit_type.h>
6 #include <ATen/core/stack.h>
7 #include <torch/csrc/Export.h>
8 #include <torch/csrc/jit/ir/ir.h>
9 
10 #include <list>
11 #include <vector>
12 
13 namespace torch::jit {
14 
15 // Replaces prim::Guard nodes with prim::BailOut nodes and
16 // computes sets of inputs needed to resume execution at
17 // bailout points
18 TORCH_API void InsertBailOuts(std::shared_ptr<Graph> graph);
19 
20 // Builds a bailout graph into `target` (which is an empty graph)
21 // for a given bailout point `bailout_index`
22 // from the original graph `orig` (the original unoptimized graph)
23 // BailOut graphs allow Interpreter to resume
24 // execution of the (un/de)optimized graph (i.e.
25 // a graph that doesn't rely on any assumptions derived from
26 // on profiling information) from a given BailOut point
27 // should any of the assumptions fail for an actual input.
28 TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
29     int64_t bailout_index,
30     const std::shared_ptr<Graph>& orig,
31     const std::shared_ptr<Graph>& target);
32 } // namespace torch::jit
33