#pragma once

#include <cstdint>
#include <functional>

#include <ATen/core/Dict.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

class RecvRpcBackward;

// DistAutogradContext stores the information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
 public:
  using GradCallback = std::function<bool(torch::Tensor&)>;

  explicit DistAutogradContext(int64_t contextId);

  // Retrieves the autograd context id for this context.
  int64_t contextId() const;

  // Records a 'send' autograd function for this context with the provided
  // message id.
  void addSendFunction(
      const std::shared_ptr<SendRpcBackward>& func,
      int64_t autograd_message_id);

  // Records a 'recv' autograd function for this context with the provided
  // message id.
  void addRecvFunction(
      std::shared_ptr<RecvRpcBackward>& func,
      int64_t autograd_message_id);

  // Given an autograd_message_id, retrieve the appropriate send function.
  std::shared_ptr<SendRpcBackward> retrieveSendFunction(
      int64_t autograd_message_id);

  // Return all send functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>> sendFunctions()
      const;

  // Return all recv functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>> recvFunctions()
      const;

  // Adds a future recording an outstanding RPC.
  void addOutstandingRpc(const c10::intrusive_ptr<rpc::JitFuture>& jitFuture);

  // Returns all gradients.
  const c10::Dict<torch::Tensor, torch::Tensor> getGradients() const;

  // This function gives a mutable grad reference to the callback.
  // If the callback returns true, it means the grad in the context
  // needs to be updated.
  void runGradCallbackForVariable(
      const torch::autograd::Variable& variable,
      GradCallback&& cb);

  DistAutogradContext(const DistAutogradContext&) = delete;
  DistAutogradContext& operator=(const DistAutogradContext&) = delete;
  DistAutogradContext(DistAutogradContext&&) = delete;
  DistAutogradContext& operator=(DistAutogradContext&&) = delete;

  // Records the workerId of a node that we sent an RPC to.
  // workerIds are added here when we attach a send function to this autograd
  // context.
  void addKnownWorkerId(const rpc::worker_id_t workerId);

  // Retrieves a set containing the known workerIds for this context.
  // These are the different workers that this context has sent RPCs to.
  std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;

 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  // Record that we would like to accumulate the provided gradient on the
  // given variable.
  void accumulateGrad(
      const torch::autograd::Variable& variable,
      const torch::Tensor& grad,
      size_t num_expected_refs);

  // Retrieve the GraphTask.
  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();

  // Set the appropriate graph task for the backward pass. Can be called only
  // once.
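  // (Sketch of the intended call sequence, inferred from the comments in this
  // header and not guaranteed by it: the distributed engine installs a
  // GraphTask via setGraphTask() when a backward pass starts, reads it back
  // through retrieveGraphTask() while executing, and calls resetGraphTask()
  // before the same context is reused for another pass.)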
  void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);

  // Resets the graph task to ensure we can run another distributed backward
  // pass for the same autograd context.
  void resetGraphTask();

  // Waits for all outstanding RPCs for this context to finish and clears all
  // outstanding RPCs held in this context. This should be called only once.
  c10::intrusive_ptr<c10::ivalue::Future> clearAndWaitForOutstandingRpcsAsync();

  void clearOutstandingRpcs();

  // Record an event to mark the completion of gradient computation. These
  // events will later help to properly synchronize gradient consumption
  // in getGradients(). We need these events because backward and
  // optimizer.step are separate RPC calls, and will occur on different CUDA
  // streams. Without synchronization, it is possible that gradients are
  // consumed before they are ready.
  void recordGradEvent(c10::Device device);

  const int64_t contextId_;

  // Set containing known worker IDs, used in cleaning up the autograd context.
  // Whenever a sendRpcBackward is attached to the autograd graph for this
  // context, the destination is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to the appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to the appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable.
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See comments for recordGradEvent(c10::Device device).
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this
  // context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and are successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};

using ContextPtr = std::shared_ptr<DistAutogradContext>;

// This class stores a shared_ptr to a DistAutogradContext instance in a
// thread local variable. The instance is provided by the call site; the class
// does not track the current context on its own. It is just a utility class.
class TORCH_API ThreadLocalDistAutogradContext {
 public:
  // Store 'new_context' in the thread local variable maintained by this class.
  explicit ThreadLocalDistAutogradContext(ContextPtr&& new_context);
  ~ThreadLocalDistAutogradContext();

  // Retrieve the stored DistAutogradContext instance.
  static ContextPtr getContextPtr();

 private:
  ContextPtr prev_context_ptr_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch
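
// Usage sketch (illustrative only, not part of this header's API surface).
// It uses only the declarations above; the function names and the origin of
// `ctx` are hypothetical, e.g. a ContextPtr handed out by whatever component
// owns DistAutogradContext instances on this worker.
//
//   void processRpcWithContext(ContextPtr ctx) {
//     // Install `ctx` as the thread-local context for this scope; the
//     // previously stored context (if any) is restored on destruction.
//     ThreadLocalDistAutogradContext guard(std::move(ctx));
//
//     // Code running on this thread can now look up the current context.
//     ContextPtr current = ThreadLocalDistAutogradContext::getContextPtr();
//     if (current != nullptr) {
//       int64_t id = current->contextId();
//       (void)id;
//     }
//   }
//
//   // GradCallback contract, per the comment on runGradCallbackForVariable():
//   // the callback gets a mutable reference to the accumulated grad and
//   // returns true if the context should store the updated value.
//   void scaleGrad(
//       DistAutogradContext& context,
//       const torch::autograd::Variable& variable) {
//     context.runGradCallbackForVariable(variable, [](torch::Tensor& grad) {
//       grad = grad * 0.5; // hypothetical in-place adjustment
//       return true; // keep the modified grad in the context
//     });
//   }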