xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/engine/dist_engine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <mutex>
4 #include <unordered_set>
5 
6 #include <torch/csrc/autograd/engine.h>
7 #include <torch/csrc/autograd/function.h>
8 #include <torch/csrc/autograd/functions/basic_ops.h>
9 #include <torch/csrc/distributed/autograd/context/context.h>
10 
11 namespace torch {
12 namespace distributed {
13 namespace autograd {
14 
15 // Forward declaration.
16 class BackwardPassCleanupGuard;
17 
18 // This is a singleton class responsible for running distributed backward
19 // passes. This engine relies heavily on the vanilla autograd engine and tries
20 // to re-use it as much as possible. This class is mostly responsible for the
21 // distributed aspects of autograd and tries to hook into the autograd engine
22 // where convenient.
23 
24 // Unlike the vanilla autograd engine, the distributed autograd engine
25 // accumulates the gradients in the appropriate DistAutogradContext. This avoids
26 // multiple trainer nodes stomping on each others gradients.
class TORCH_API DistEngine {
 public:
  // Retrieve the singleton instance.
  static DistEngine& getInstance();

  // Given a list of root variables, start the distributed backwards pass from
  // these variables and accumulate all the gradients in the current autograd
  // context on each node. This method is used to kickoff distributed autograd
  // on a single node.
  void execute(
      int64_t context_id,
      const torch::autograd::variable_list& roots,
      bool retainGraph);

  // Given a send function to execute in the autograd engine, ensures we compute
  // dependencies once for this node and enqueues the send function for execute
  // in the engine.
  // This method is used to kick off the autograd computation on a node when it
  // receives gradients from the corresponding 'recv' method on another node.
  // The gradients are accumulated in the provided autograd context.
  // Returns a future that completes when the local portion of the backward
  // pass triggered by this send function has finished.
  c10::intrusive_ptr<c10::ivalue::Future> executeSendFunctionAsync(
      const ContextPtr& autogradContext,
      const std::shared_ptr<SendRpcBackward>& sendFunction,
      bool retainGraph);

  // Number of backward passes currently running for the Distributed Engine.
  size_t numBackwardPasses() const;

  // Returns key-value pairs consisting of useful debugging information related
  // to distributed autograd.
  std::unordered_map<std::string, int> getDebugInfo() const;

  // Singleton: copy and move construction/assignment are disallowed.
  DistEngine(const DistEngine&) = delete;
  DistEngine& operator=(const DistEngine&) = delete;
  DistEngine(DistEngine&&) = delete;
  DistEngine& operator=(DistEngine&&) = delete;

 private:
  // Make sure this is a singleton (constructible only via getInstance()).
  DistEngine();
  ~DistEngine();

  // Validates the input roots for the backward computations and retrieves the
  // appropriate root edges and corresponding gradients. Populates root_edges
  // with the appropriate gradient edges and grads with the gradients for each
  // edge.
  void validateRootsAndRetrieveEdges(
      const torch::autograd::variable_list& roots,
      torch::autograd::edge_list& rootEdges,
      torch::autograd::variable_list& grads);

  // Given the autograd context, root edges and grads, we compute dependencies
  // for the local node and fill out the provided GraphTask and GraphRoot with
  // appropriate information for the local autograd engine.
  // We also determine all leaf nodes(functions) in the graph and accumulate
  // them in outputEdges.
  void computeDependencies(
      const ContextPtr& context,
      const torch::autograd::edge_list& rootEdges,
      const torch::autograd::variable_list& grads,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      torch::autograd::edge_list& outputEdges,
      bool retainGraph);

  // Given a pre-populated GraphTask and a root node, compute the backward pass
  // for the autograd graph until the graph task ready queue is empty.
  //
  // This method assumes that the appropriate GraphTask has already been
  // initialized appropriately. It will construct a local ready queue to
  // traverse the GraphTask instead of using the GraphTask embedded
  // cpu_ready_queue, this is because dist engine might run the same GraphTask
  // from different SendFunctions concurrently in different threads. The method
  // will only mark the GraphTask as completed when it needs to, which means it
  // might not mark as completed for every call as dist engine would like to
  // keep the GraphTask alive when it not receives all gradients.
  //
  // When `incrementOutstandingTasks=false`, the function does not increment
  // 'outstanding_tasks_' in the appropriate GraphTask. It is assumed we've
  // already done this before hand for this task (to ensure we don't pre-mark
  // this graph_task as completed). This is useful in the distributed autograd
  // case where we need to increment 'outstanding_tasks_' first to indicate the
  // local autograd engine the graph task is not completed until it receives the
  // signals from other workers over the network.
  //
  // XXX: calling this function assumes that we will have NO GPU nodetasks be
  // executed for the graph_task, the caller of this function need to ensure
  // this otherwise there will be undefined behaviors. A correct way to fix this
  // is to re-design the autograd engine so that GPU worker thread to behave the
  // same as CPU caller thread, record the operation/thread for the device, and
  // reuse it in backward.
  // TODO: 1. Add assert in the dist engine to ensure no GPU NodeTasks during
  // backward
  //       2. properly setup the thread local ready queue to enable reentrant
  //       backwards
  void execute_graph_task_until_ready_queue_empty(
      torch::autograd::NodeTask&& node_task,
      bool incrementOutstandingTasks = true);

  // Run the local autograd engine using the provided graphTask and graphRoot
  // and accumulate the gradients part 'outputEdges' in the provided autograd
  // context.
  c10::intrusive_ptr<c10::ivalue::Future> runEngineAndAccumulateGradients(
      const ContextPtr& autogradContext,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      const torch::autograd::edge_list& outputEdges,
      bool incrementOutStandingTasks = true);

  // Run after the backward pass is done to appropriately cleanup structures.
  void cleanupBackwardPass(const ContextPtr& autogradContext);

  // Global thread to execute CPU continuations.
  void globalCpuThread(
      const std::shared_ptr<torch::autograd::ReadyQueue>& ready_queue);

  // Set of autograd context_ids, which we have already initialized for
  // distributed autograd on this node (e.g.: already computed dependencies)
  std::unordered_set<int64_t> initializedContextIds_;

  // Lock presumably guarding initializedContextIds_ (per naming; usage is in
  // the .cpp). Mutable so const methods can acquire it.
  mutable std::mutex initializedContextIdsLock_;

  // Reference to local autograd engine.
  torch::autograd::Engine& engine_;

  // Ready queue used by the CPU thread in distributed engine.
  // See Note [GPU to CPU continuations]
  std::shared_ptr<torch::autograd::ReadyQueue> global_cpu_ready_queue_;

  // See Note [GPU to CPU continuations]
  std::thread global_cpu_thread_;

  // Allows the cleanup guard below to invoke the private cleanupBackwardPass().
  friend class BackwardPassCleanupGuard;
};
159 
160 // Guard to clean up resources once the backward pass is done.
161 class BackwardPassCleanupGuard {
162  public:
BackwardPassCleanupGuard(ContextPtr autogradContext)163   explicit BackwardPassCleanupGuard(ContextPtr autogradContext)
164       : autogradContext_(std::move(autogradContext)) {}
165 
~BackwardPassCleanupGuard()166   ~BackwardPassCleanupGuard() {
167     DistEngine::getInstance().cleanupBackwardPass(autogradContext_);
168   }
169 
170  private:
171   ContextPtr autogradContext_;
172 };
173 
174 } // namespace autograd
175 } // namespace distributed
176 } // namespace torch
177