#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <c10/util/CallOnce.h>

#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

namespace torch::autograd {
struct ReadyQueue;
}

namespace torch::autograd {

// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a program
// holds more than 65 locks in one thread at once.
// Since we hold a mutex in every one of our custom C++ autograd Nodes, we
// would like to avoid TSAN complaints about this when doing reentrant
// backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;

void set_device(int device);
TORCH_API void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here.  Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When a worker receives a task with isShutdownTask = true, it will
  // immediately exit. The engine sends a shutdown task to every queue upon
  // its destruction.
  bool isShutdownTask_;

  int getReentrantDepth() const;

  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};

// Guard that sets and restores checkpoint_valid
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

 private:
  bool prev_checkpoint_valid_state;
};
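
// Example (illustrative RAII sketch only; the real call sites live in
// engine.cpp, not in this header):
//
//   {
//     CheckpointValidGuard guard(graph_task);
//     // while the guard is alive, is_checkpoint_valid() reflects whether
//     // checkpointed recomputation is allowed for this graph_task
//   } // the previous checkpoint_valid state is restored here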

struct ReadyQueue {
 private:
  // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
  // Shutdown tasks come first, then empty NodeTasks are next.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const& t1, NodeTask const& t2) {
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (t2.isShutdownTask_) {
        return true;
      } else if (!t1.fn_ || t1.isShutdownTask_) {
        return false;
      } else if (!t2.fn_) {
        return true;
      } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      } else {
        return t1.getReentrantDepth() < t2.getReentrantDepth();
      }
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on the heap_
  std::condition_variable not_empty_;
  // To protect reads and writes to heap_
  mutable std::mutex mutex_;

  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
      heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should almost
  // always be true and is only set to false in certain cases (see docs for
  // DistEngine.execute_graph_task_until_ready_queue_empty)
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  void pushShutdownTask();
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};
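
// For example (illustrative reading of CompareNodeTaskTime, not a guarantee
// beyond what the comparator states): pop() returns pending shutdown tasks
// first, then empty NodeTasks; among ordinary tasks at the same reentrant
// depth, the NodeTask whose fn_ has the higher sequence_nr is popped earlier,
// so Nodes created later in the forward pass are processed earlier during
// backward.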

// A single instance of this struct should be created for the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely on
// this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  static Engine& get_base_engine();

  // compiled_autograd needs to live in a different .so file so that it
  // can have python symbols, so we add a layer of indirection
  // see [Note: Compiled Autograd]
  typedef variable_list (*compiled_autograd_fn)(
      const std::shared_ptr<Node>& graph_root,
      GraphTask& graph_task,
      bool accumulate_grad,
      const edge_list& outputs);
  static void set_compiled_autograd(compiled_autograd_fn fn);

  Engine(const Engine&) = delete;
  Engine(Engine&&) = delete;
  virtual ~Engine();

  // Given a list of (Node, input number) pairs, computes the value of the
  // graph by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {});
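
  // Typical use (illustrative sketch only; building the root edges via
  // torch::autograd::impl::gradient_edge and using torch::ones_like are
  // assumptions about code outside this header):
  //
  //   auto& engine = Engine::get_default_engine();
  //   edge_list roots{torch::autograd::impl::gradient_edge(loss)};
  //   variable_list grads{torch::ones_like(loss)};
  //   // roughly what loss.backward() does: run once, build no higher-order
  //   // graph, and accumulate results into the leaves' .grad fields
  //   engine.execute(
  //       roots, grads, /*keep_graph=*/false, /*create_graph=*/false,
  //       /*accumulate_grad=*/true);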

  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph.
  //
  // NB: This API should only be used by internal autograd-specific
  // machinery and shouldn't be exposed to users in any way.
  virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer);

  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return std::make_unique<AnomalyMetadata>();
  }

  virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
    return nullptr;
  }

  // We pass cpu_ready_queue to evaluate_function, so that it knows
  // the correct ready queue to push to after a NodeTask is ready
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue);

  void initialize_device_threads_pool();
  virtual void thread_on_exception(
      const std::shared_ptr<GraphTask>& graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);

  void queue_callback(std::function<void()> callback);
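
  // Example (illustrative): from inside a backward Node, schedule work to run
  // once the current backward pass finishes:
  //
  //   Engine::get_default_engine().queue_callback([] {
  //     // runs after the currently executing GraphTask completes
  //   });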

  bool is_checkpoint_valid();

  // Should be called after fork to notify that worker threads are gone
  void release_workers();

  // Must be called by subclasses before destructing to avoid a
  // data-race-on-vptr.
  void stop();
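
  // For example (sketch of the expected subclass pattern; MyEngine is a
  // hypothetical name, not a class defined in this codebase):
  //
  //   struct MyEngine : Engine {
  //     ~MyEngine() override {
  //       stop(); // runs while the dynamic type is still MyEngine
  //     }
  //   };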

  // Initializes a device thread for the autograd engine.
  virtual void thread_init(
      int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment = true);

 protected:
  Engine();
  void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);

  // Initializes the thread-local ready queue with the ready queue that is
  // created elsewhere (e.g. thread_init, Engine::execute, etc.), or creates a
  // new ready queue if ready_queue is not provided.
  void init_local_ready_queue(
      std::shared_ptr<ReadyQueue> ready_queue = nullptr);

  std::shared_ptr<ReadyQueue> ready_queue(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      at::Device device);
  std::shared_ptr<ReadyQueue> ready_queue_by_index(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      int device_index);
  // Starts device threads (CUDA, XLA, etc.) in Engine;
  // note that it does NOT start the CPU thread.
  void start_device_threads();
  void increment_non_reentrant_thread_count();
  void decrement_non_reentrant_thread_count();
  virtual void thread_main(const std::shared_ptr<GraphTask>& task);
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);

  // Ensures device_ready_queues_ are initialized only once
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::once_flag start_device_threads_flag_;
  // Safe to read device_ready_queues_ without synchronization after
  // initialization
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex post_callbacks_lock_;

  // How many nested reentrant calls are allowed until a new thread is used
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_{0};
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtask_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    ThreadPoolShared() = default;
  };

  // Temporary workaround until shutting down threads is done
  // We need shared ownership of all these objects because the threads are
  // leaked when Engine shuts down, so there may be threads waiting on work_ for
  // the graphtasks_queue_ to be nonempty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  // Number of non-reentrant threads
  std::atomic<uint32_t> non_reentrant_device_thread_count_;
  // Destructor will wait for non-reentrant threads to finish
  std::condition_variable non_reentrant_device_thread_condvar_;
  std::mutex non_reentrant_device_thread_mutex_;
  // stop() must be called before the destruction path goes down to the base
  // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
  // whether stop() has already been called, so we can call this in every
  // destructor of the class hierarchy.
  bool stopped_{false};
};
283 
284 // allow python_engine to override the default engine when it loads
285 using EngineStub = Engine& (*)();
286 TORCH_API void set_default_engine_stub(EngineStub stub);
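
// For example (sketch of how a python-aware engine could install itself; the
// names PythonEngine and get_python_engine are assumptions about code outside
// this header):
//
//   Engine& get_python_engine() {
//     static PythonEngine engine;
//     return engine;
//   }
//   // called once when the python bindings are loaded
//   set_default_engine_stub(get_python_engine);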

} // namespace torch::autograd