#pragma once

#include <mutex>
#include <unordered_set>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/distributed/autograd/context/context.h>

namespace torch {
namespace distributed {
namespace autograd {

// Forward declaration.
class BackwardPassCleanupGuard;

// This is a singleton class responsible for running distributed backward
// passes. This engine relies heavily on the vanilla autograd engine and tries
// to re-use it as much as possible. This class is mostly responsible for the
// distributed aspects of autograd and hooks into the autograd engine where
// convenient.

// Unlike the vanilla autograd engine, the distributed autograd engine
// accumulates the gradients in the appropriate DistAutogradContext. This
// avoids multiple trainer nodes stomping on each other's gradients.
class TORCH_API DistEngine {
 public:
  // Retrieve the singleton instance.
  static DistEngine& getInstance();

  // Given a list of root variables, start the distributed backward pass from
  // these variables and accumulate all the gradients in the current autograd
  // context on each node. This method is used to kick off distributed
  // autograd on a single node.
  void execute(
      int64_t context_id,
      const torch::autograd::variable_list& roots,
      bool retainGraph);

  // Given a send function to execute in the autograd engine, ensures we
  // compute dependencies only once for this node and enqueues the send
  // function for execution in the engine.
  // This method is used to kick off the autograd computation on a node when
  // it receives gradients from the corresponding 'recv' method on another
  // node. The gradients are accumulated in the provided autograd context.
  c10::intrusive_ptr<c10::ivalue::Future> executeSendFunctionAsync(
      const ContextPtr& autogradContext,
      const std::shared_ptr<SendRpcBackward>& sendFunction,
      bool retainGraph);

  // Number of backward passes currently running for the distributed engine.
  size_t numBackwardPasses() const;

  // Returns key-value pairs consisting of useful debugging information
  // related to distributed autograd.
  std::unordered_map<std::string, int> getDebugInfo() const;

  DistEngine(const DistEngine&) = delete;
  DistEngine& operator=(const DistEngine&) = delete;
  DistEngine(DistEngine&&) = delete;
  DistEngine& operator=(DistEngine&&) = delete;

 private:
  // Make sure this is a singleton.
  DistEngine();
  ~DistEngine();

  // Validates the input roots for the backward computation and retrieves the
  // appropriate root edges and corresponding gradients. Populates rootEdges
  // with the appropriate gradient edges and grads with the gradients for each
  // edge.
  void validateRootsAndRetrieveEdges(
      const torch::autograd::variable_list& roots,
      torch::autograd::edge_list& rootEdges,
      torch::autograd::variable_list& grads);
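
  // Illustrative sketch of what this validation conceptually produces for
  // each root (an assumption about the implementation in the .cpp file, not
  // a verbatim copy): the root's gradient edge plus a gradient of ones with
  // the root's shape.
  //
  //   for (const auto& root : roots) {
  //     TORCH_CHECK(root.requires_grad(), "requires_grad not set on root");
  //     rootEdges.emplace_back(torch::autograd::impl::gradient_edge(root));
  //     grads.emplace_back(at::ones_like(root));
  //   }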

  // Given the autograd context, root edges and grads, we compute dependencies
  // for the local node and fill out the provided GraphTask and GraphRoot with
  // appropriate information for the local autograd engine. We also determine
  // all leaf nodes (functions) in the graph and accumulate them in
  // outputEdges.
  void computeDependencies(
      const ContextPtr& context,
      const torch::autograd::edge_list& rootEdges,
      const torch::autograd::variable_list& grads,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      torch::autograd::edge_list& outputEdges,
      bool retainGraph);

  // Given a pre-populated GraphTask and a root node, compute the backward
  // pass for the autograd graph until the graph task ready queue is empty.
  //
  // This method assumes that the appropriate GraphTask has already been
  // initialized. It constructs a local ready queue to traverse the GraphTask
  // instead of using the GraphTask's embedded cpu_ready_queue, because the
  // dist engine might run the same GraphTask from different SendFunctions
  // concurrently in different threads. The method only marks the GraphTask as
  // completed when it needs to, which means it might not mark it as completed
  // on every call: the dist engine keeps the GraphTask alive until it has
  // received all gradients.
  //
  // When `incrementOutstandingTasks=false`, the function does not increment
  // 'outstanding_tasks_' in the appropriate GraphTask. It is assumed we've
  // already done this beforehand for this task (to ensure we don't pre-mark
  // this graph_task as completed). This is useful in the distributed autograd
  // case, where we need to increment 'outstanding_tasks_' first to indicate
  // to the local autograd engine that the graph task is not complete until it
  // receives the signals from other workers over the network.
  //
  // XXX: Calling this function assumes that NO GPU NodeTasks will be executed
  // for the graph_task; the caller needs to ensure this, otherwise the
  // behavior is undefined. A correct way to fix this is to re-design the
  // autograd engine so that the GPU worker thread behaves the same as the CPU
  // caller thread: record the operation/thread for the device and reuse it in
  // backward.
  // TODO: 1. Add an assert in the dist engine to ensure no GPU NodeTasks
  //          during backward.
  //       2. Properly set up the thread local ready queue to enable reentrant
  //          backwards.
  void execute_graph_task_until_ready_queue_empty(
      torch::autograd::NodeTask&& node_task,
      bool incrementOutstandingTasks = true);

  // Run the local autograd engine using the provided graphTask and graphRoot,
  // and accumulate the gradients for 'outputEdges' in the provided autograd
  // context.
  c10::intrusive_ptr<c10::ivalue::Future> runEngineAndAccumulateGradients(
      const ContextPtr& autogradContext,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      const torch::autograd::edge_list& outputEdges,
      bool incrementOutstandingTasks = true);

  // Run after the backward pass is done to appropriately clean up structures.
  void cleanupBackwardPass(const ContextPtr& autogradContext);

  // Global thread to execute CPU continuations.
  void globalCpuThread(
      const std::shared_ptr<torch::autograd::ReadyQueue>& ready_queue);

  // Set of autograd context_ids which we have already initialized for
  // distributed autograd on this node (e.g. already computed dependencies).
  std::unordered_set<int64_t> initializedContextIds_;

  mutable std::mutex initializedContextIdsLock_;
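
  // Illustrative sketch (an assumption about the .cpp implementation, not a
  // verbatim copy) of how the set and lock above keep dependency computation
  // to at most one run per context; 'contextId' stands for the id of the
  // current DistAutogradContext:
  //
  //   std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  //   if (initializedContextIds_.find(contextId) ==
  //       initializedContextIds_.end()) {
  //     // First send function seen for this context: compute dependencies,
  //     // then record that this context has been initialized.
  //     initializedContextIds_.insert(contextId);
  //   }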

  // Reference to local autograd engine.
  torch::autograd::Engine& engine_;

  // Ready queue used by the CPU thread in the distributed engine.
  // See Note [GPU to CPU continuations]
  std::shared_ptr<torch::autograd::ReadyQueue> global_cpu_ready_queue_;

  // See Note [GPU to CPU continuations]
  std::thread global_cpu_thread_;

  friend class BackwardPassCleanupGuard;
};

// Guard to clean up resources once the backward pass is done.
class BackwardPassCleanupGuard {
 public:
  explicit BackwardPassCleanupGuard(ContextPtr autogradContext)
      : autogradContext_(std::move(autogradContext)) {}

  ~BackwardPassCleanupGuard() {
    DistEngine::getInstance().cleanupBackwardPass(autogradContext_);
  }

 private:
  ContextPtr autogradContext_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch
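
// Usage sketch (illustrative assumptions, not verbatim from the .cpp files):
// a trainer that owns the loss kicks off a distributed backward pass through
// the singleton, and BackwardPassCleanupGuard is the RAII helper that ties
// cleanupBackwardPass() to a scope so cleanup runs even if the pass throws.
// 'contextId' and 'loss' are hypothetical names for values produced by the
// surrounding distributed autograd context and forward pass:
//
//   auto& engine = torch::distributed::autograd::DistEngine::getInstance();
//   engine.execute(contextId, {loss}, /*retainGraph=*/false);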