xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/graph_executor_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/jit/runtime/graph_executor.h>
3 
4 #include <ATen/core/ivalue.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/autograd/grad_mode.h>
7 #include <torch/csrc/jit/frontend/tracer.h>
8 #include <torch/csrc/jit/ir/ir.h>
9 #include <torch/csrc/jit/passes/shape_analysis.h>
10 #include <torch/csrc/jit/resource_guard.h>
11 #include <torch/csrc/jit/runtime/argument_spec.h>
12 #include <torch/csrc/jit/runtime/autodiff.h>
13 #include <torch/csrc/jit/runtime/custom_operator.h>
14 #include <torch/csrc/jit/runtime/interpreter.h>
15 #include <torch/csrc/jit/runtime/profiling_record.h>
16 
17 #include <torch/csrc/autograd/edge.h>
18 #include <torch/csrc/autograd/function.h>
19 #include <torch/csrc/jit/frontend/ir_emitter.h>
20 #include <torch/csrc/jit/runtime/logging.h>
21 
22 #include <cstdint>
23 #include <iterator>
24 #include <memory>
25 #include <mutex>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 
30 namespace torch::jit {
31 
32 void packGradient(const Gradient& gradient, Node* dnode);
33 bool needsGradient(const std::shared_ptr<const Graph>& graph);
34 void runOptimization(
35     std::shared_ptr<Graph>& graph,
36     bool unroll_non_constant_loops = true,
37     bool const_prop_user_classes = true);
38 void runNondiffOptimization(
39     std::shared_ptr<Graph>& graph,
40     bool strict_fuser_check = false);
41 void debugSetAutodiffSubgraphInlining(bool state);
42 bool TORCH_API getAutodiffSubgraphInlining();
43 
44 void debugSetFusionGroupInlining(bool state);
45 bool getFusionGroupInlining();
46 
// Tunable parameters for deciding when to create/keep subgraphs of
// differentiable code.
//
// Declared `inline constexpr` so that every translation unit including this
// header refers to the same object (a plain namespace-scope `const` has
// internal linkage and yields one copy per translation unit) and so the values
// are usable in constant expressions.
inline constexpr size_t autodiffSubgraphNodeThreshold = 2;
inline constexpr size_t autodiffSubgraphInlineThreshold = 5;
51 
52 // a Graph can be created via tracing, or via a language-based frontend
53 // GraphExecutor runs it. It can run the same graph on many different sizes
54 // and different requires_grad states, and handles specializations for each
55 // situation. GraphExecutor is completely unaware of tracing or module
56 // parameters to keep the tracing concerns separated.
57 struct GraphExecutorImplBase {
prepareGraphGraphExecutorImplBase58   static std::shared_ptr<Graph> prepareGraph(
59       const std::shared_ptr<Graph>& graph) {
60     auto copy = graph->copy();
61     EraseShapeInformation(copy);
62     return copy;
63   }
64 
GraphExecutorImplBaseGraphExecutorImplBase65   GraphExecutorImplBase(
66       const std::shared_ptr<Graph>& graph,
67       std::string function_name)
68       : graph(prepareGraph(graph)),
69         function_name_(std::move(function_name)),
70         num_inputs(this->graph->inputs().size()),
71         num_outputs(this->graph->outputs().size()) {}
72 
73   // entry point where execution begins
74   void run(Stack& stack);
75   c10::intrusive_ptr<Future> runAsync(
76       Stack& stack,
77       TaskLauncher taskLauncher = at::launch);
78 
79   virtual const ExecutionPlan& getPlanFor(
80       Stack& stack,
81       std::optional<size_t> remaining_bailout_depth = std::nullopt) = 0;
82   virtual GraphExecutorState getDebugState() = 0;
83   virtual ~GraphExecutorImplBase() = default;
84 
isOptimizedGraphExecutorImplBase85   virtual bool isOptimized() const {
86     return false;
87   }
88 
89  protected:
90   friend struct GraphExecutor;
91 
92   // The unoptimized starting graph. This field is effectively const, but we
93   // can't make it so because Graph::copy() is not const (and making it const is
94   // not that easy at this point).
95   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
96   std::shared_ptr<Graph> graph;
97   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
98   std::string function_name_;
99 
100   // If false, we'll run the graph as we get it, without any optimizations.
101   // Useful for debugging.
102   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
103   const size_t num_inputs;
104   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
105   const size_t num_outputs;
106 
107   // GraphExecutors can be accessed from multiple threads, so this thread needs
108   // to be held every time we access the fallback or plan_cache.
109   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
110   std::mutex compile_mutex;
111 };
112 
113 } // namespace torch::jit
114