// torch/csrc/distributed/autograd/context/context.cpp
#include <torch/csrc/distributed/autograd/context/context.h>

#include <functional>

#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;

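// Each context is keyed by its context id. The cached VirtualGuardImpl targets
// CUDA when it is available (CPU otherwise) and is used below to query the
// current streams when recording and blocking on gradient-ready events.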
DistAutogradContext::DistAutogradContext(int64_t contextId)
    : contextId_(contextId),
      impl_(c10::impl::VirtualGuardImpl{
          at::hasCUDA() ? c10::DeviceType::CUDA : c10::DeviceType::CPU}) {}

int64_t DistAutogradContext::contextId() const {
  return contextId_;
}

std::unordered_set<rpc::worker_id_t> DistAutogradContext::getKnownWorkerIds()
    const {
  std::lock_guard<std::mutex> guard(lock_);
  return knownWorkerIds_;
}

void DistAutogradContext::addKnownWorkerId(const rpc::worker_id_t workerId) {
  std::lock_guard<std::mutex> guard(lock_);
  knownWorkerIds_.insert(workerId);
}

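// Records the 'send' autograd function for the given autograd message id. A
// message id may only be registered once per context.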
void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

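// Records the 'recv' autograd function for the given autograd message id,
// mirroring addSendFunction above.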
void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

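// The two accessors below return copies of the send/recv function maps so that
// callers can iterate over them without holding the context lock.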
std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
DistAutogradContext::sendFunctions() const {
  std::lock_guard<std::mutex> guard(lock_);
  return sendAutogradFunctions_;
}

std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
DistAutogradContext::recvFunctions() const {
  std::lock_guard<std::mutex> guard(lock_);
  return recvAutogradFunctions_;
}

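// Accumulates 'grad' into the gradient stored for 'variable' in this context,
// creating the entry on first use. Accumulation runs on the variable's forward
// stream with grad mode disabled.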
void DistAutogradContext::accumulateGrad(
    const torch::autograd::Variable& variable,
    const torch::Tensor& grad,
    size_t num_expected_refs) {
  TORCH_INTERNAL_ASSERT(grad.defined());
  TORCH_INTERNAL_ASSERT(variable.requires_grad());

  std::lock_guard<std::mutex> guard(lock_);
  auto it = accumulatedGrads_.find(variable);
  at::Tensor old_grad;
  if (it != accumulatedGrads_.end()) {
    // Accumulate multiple grads on the same variable.
    old_grad = it->value();
  }

  // Gradients are computed using the forward streams. The local autograd
  // engine uses the AccumulateGrad function to retrieve and apply the forward
  // stream during the backward computation. In distributed autograd, we call
  // AccumulateGrad::accumulateGrad directly and skip the CUDA stream
  // restoration done by the autograd function. Hence, we manually restore the
  // stream here to get the streams correct.
  auto forward_stream =
      torch::autograd::impl::grad_accumulator(variable)->stream();
  c10::OptionalStreamGuard stream_guard(forward_stream);

  // No higher order gradients supported in distributed autograd.
  AutoGradMode grad_mode(false);

  // TODO: Need to bump 'num_expected_refs' here when we support post_hooks for
  // distributed autograd as part of
  // https://github.com/pytorch/pytorch/issues/33482
  AccumulateGrad::accumulateGrad(
      variable,
      old_grad,
      grad,
      num_expected_refs,
      [this, &variable](at::Tensor&& grad_update) {
        auto device = grad_update.device();
        accumulatedGrads_.insert(variable, std::move(grad_update));
        recordGradEvent(device);
      });
}

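// GraphTask management for the distributed backward pass: the GraphTask can be
// set at most once per context, retrieved while a pass is in flight, and
// cleared again via resetGraphTask().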
std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
    retrieveGraphTask() {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(graphTask_);
  return graphTask_;
}

void DistAutogradContext::setGraphTask(
    std::shared_ptr<torch::autograd::GraphTask> graphTask) {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      !graphTask_,
      "Cannot set GraphTask multiple times for the same autograd context");
  graphTask_ = std::move(graphTask);
}

void DistAutogradContext::resetGraphTask() {
  std::lock_guard<std::mutex> guard(lock_);
  graphTask_ = nullptr;
}

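// Tracks an RPC sent as part of the backward pass. If the RPC fails, the error
// is propagated to the GraphTask's future (when the GraphTask is still valid)
// so the local autograd engine can surface it.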
void DistAutogradContext::addOutstandingRpc(
    const c10::intrusive_ptr<rpc::JitFuture>& jitFuture) {
  jitFuture->addCallback([this](rpc::JitFuture& future) {
    if (future.hasError()) {
      // If we have an error, let the local autograd engine know about it.
      std::unique_lock<std::mutex> lock(lock_);
      if (graphTask_) {
        graphTask_->set_exception_without_signal(nullptr);
        lock.unlock();
        if (!graphTask_->future_completed_.exchange(true)) {
          graphTask_->future_result_->setErrorIfNeeded(future.exception_ptr());
        }
      } else {
        LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
                     << future.tryRetrieveErrorMessage();
      }
    }
  });
  std::lock_guard<std::mutex> guard(lock_);
  outStandingRpcs_.push_back(jitFuture);
}

void DistAutogradContext::clearOutstandingRpcs() {
  std::unique_lock<std::mutex> lock(lock_);
  outStandingRpcs_.clear();
}

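// Records a gradient-ready event for 'device' on its current stream. The event
// is blocked on in getGradients() to ensure gradient computation has finished
// before the gradients are read. This is a no-op for CPU devices.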
void DistAutogradContext::recordGradEvent(c10::Device device) {
  if (device.is_cuda()) {
    auto iter = gradReadyEvents_.find(device);
    if (iter == gradReadyEvents_.end()) {
      c10::Event event(device.type());
      event.record(impl_.getStream(event.device()));
      gradReadyEvents_.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(device),
          std::forward_as_tuple(std::move(event)));
    } else {
      iter->second.record(impl_.getStream(device));
    }
  }
}

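// Moves out the list of outstanding RPCs and returns a future that completes
// once all of them have finished. The first RPC error (if any) is set on the
// returned future; subsequent errors are ignored.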
c10::intrusive_ptr<c10::ivalue::Future> DistAutogradContext::
    clearAndWaitForOutstandingRpcsAsync() {
  std::unique_lock<std::mutex> lock(lock_);
  auto outStandingRpcs = std::move(outStandingRpcs_);
  lock.unlock();

  struct State {
    explicit State(int32_t count)
        : future(
              c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get())),
          remaining(count) {}
    c10::intrusive_ptr<c10::ivalue::Future> future;
    std::atomic<int32_t> remaining;
    std::atomic<bool> alreadySentError{false};
  };
  auto state = std::make_shared<State>(outStandingRpcs.size());
  if (outStandingRpcs.empty()) {
    state->future->markCompleted(c10::IValue());
  } else {
    for (auto& rpc : outStandingRpcs) {
      rpc->addCallback([state](rpc::JitFuture& future) {
        if (future.hasError()) {
          // If there's an error, we want to setError() on the future,
          // unless another error has already been sent - use a CAS to
          // guard.
          //
          // Don't decrement num remaining here! (We don't need to, since
          // memory handling is separate). If we simply don't decrement on
          // errors, reaching 0 means that there were no errors - and hence,
          // we can just markCompleted() without any other checking there.
          bool expectedAlreadySent = false;
          if (state->alreadySentError.compare_exchange_strong(
                  expectedAlreadySent, true)) {
            state->future->setError(future.exception_ptr());
          }
          return;
        }

        if (--state->remaining == 0) {
          state->future->markCompleted(c10::IValue());
        }
      });
    }
  }
  return state->future;
}

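// Looks up the 'send' autograd function registered for the given autograd
// message id and throws if no such function exists.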
std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
    int64_t autograd_message_id) {
  std::lock_guard<std::mutex> guard(lock_);
  auto it = sendAutogradFunctions_.find(autograd_message_id);
  TORCH_CHECK(
      it != sendAutogradFunctions_.end(),
      "Could not find send function for autograd message id: ",
      autograd_message_id);
  return it->second;
}

const c10::Dict<torch::Tensor, torch::Tensor> DistAutogradContext::
    getGradients() const {
  std::lock_guard<std::mutex> guard(lock_);
  // Block current streams before accessing gradients to make sure that
  // gradient computations are finished before use.
  for (auto& entry : gradReadyEvents_) {
    auto& event = entry.second;
    event.block(impl_.getStream(event.device()));
  }
  return accumulatedGrads_;
}

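// Runs 'cb' on the accumulated gradient for 'variable', which must exist. If
// the callback returns true, the (possibly modified) gradient is written back
// into the map and a new gradient-ready event is recorded for its device.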
void DistAutogradContext::runGradCallbackForVariable(
    const torch::autograd::Variable& variable,
    GradCallback&& cb) {
  torch::Tensor grad;
  {
    std::lock_guard<std::mutex> guard(lock_);
    auto it = accumulatedGrads_.find(variable);
    TORCH_INTERNAL_ASSERT(
        it != accumulatedGrads_.end(),
        "The grad for the variable should exist in dist_autograd context.");
    grad = it->value();
  }
  if (cb(grad)) {
    std::lock_guard<std::mutex> guard(lock_);
    auto device = grad.device();
    // Update the grad in the map.
    accumulatedGrads_.insert_or_assign(variable, std::move(grad));
    recordGradEvent(device);
  }
}

namespace {
thread_local ContextPtr tl_context_ptr;
} // namespace

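// ThreadLocalDistAutogradContext is an RAII guard: it installs the given
// context as the current thread-local context and restores the previous one on
// destruction.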
ThreadLocalDistAutogradContext::ThreadLocalDistAutogradContext(
    ContextPtr&& new_context)
    : prev_context_ptr_(std::move(tl_context_ptr)) {
  tl_context_ptr = std::move(new_context);
}

ThreadLocalDistAutogradContext::~ThreadLocalDistAutogradContext() {
  tl_context_ptr = std::move(prev_context_ptr_);
}

// static
ContextPtr ThreadLocalDistAutogradContext::getContextPtr() {
  return tl_context_ptr;
}

} // namespace autograd
} // namespace distributed
} // namespace torch