#include <torch/csrc/distributed/autograd/context/context.h>

#include <functional>

#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;

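// Each context is identified by a globally unique context id. The guard
// implementation (impl_) is CUDA-aware when CUDA is available so that
// gradient-ready events can be recorded on and blocked against the right
// streams.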
DistAutogradContext::DistAutogradContext(int64_t contextId)
    : contextId_(contextId),
      impl_(c10::impl::VirtualGuardImpl{
          at::hasCUDA() ? c10::DeviceType::CUDA : c10::DeviceType::CPU}) {}

int64_t DistAutogradContext::contextId() const {
  return contextId_;
}

std::unordered_set<rpc::worker_id_t> DistAutogradContext::getKnownWorkerIds()
    const {
  std::lock_guard<std::mutex> guard(lock_);
  return knownWorkerIds_;
}

void DistAutogradContext::addKnownWorkerId(const rpc::worker_id_t workerId) {
  std::lock_guard<std::mutex> guard(lock_);
  knownWorkerIds_.insert(workerId);
}

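// Records the 'send' autograd function for the given autograd message id.
// Each message id may only be registered once.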
void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

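// Records the 'recv' autograd function for the given autograd message id.
// Each message id may only be registered once.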
void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

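// The getters below return copies of the maps under the lock, so callers can
// iterate over the result without holding 'lock_'.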
std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
DistAutogradContext::sendFunctions() const {
  std::lock_guard<std::mutex> guard(lock_);
  return sendAutogradFunctions_;
}

std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
DistAutogradContext::recvFunctions() const {
  std::lock_guard<std::mutex> guard(lock_);
  return recvAutogradFunctions_;
}

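// Accumulates 'grad' into the gradient stored for 'variable' in this context,
// creating the entry if it does not exist yet. 'num_expected_refs' is
// forwarded to AccumulateGrad, which uses it to decide whether the incoming
// gradient can be reused in place instead of copied.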
void DistAutogradContext::accumulateGrad(
    const torch::autograd::Variable& variable,
    const torch::Tensor& grad,
    size_t num_expected_refs) {
  TORCH_INTERNAL_ASSERT(grad.defined());
  TORCH_INTERNAL_ASSERT(variable.requires_grad());

  std::lock_guard<std::mutex> guard(lock_);
  auto it = accumulatedGrads_.find(variable);
  at::Tensor old_grad;
  if (it != accumulatedGrads_.end()) {
    // Accumulate multiple grads on the same variable.
    old_grad = it->value();
  }

  // Gradients are computed on the forward-pass streams. During a local
  // backward pass, the AccumulateGrad autograd function restores the forward
  // stream before accumulating. Distributed autograd calls
  // AccumulateGrad::accumulateGrad directly and skips that stream
  // restoration, so we set up the forward stream manually here.
  auto forward_stream =
      torch::autograd::impl::grad_accumulator(variable)->stream();
  c10::OptionalStreamGuard stream_guard(forward_stream);

  // No higher order gradients supported in distributed autograd.
  AutoGradMode grad_mode(false);

  // TODO: Need to bump 'num_expected_refs' here when we support post_hooks for
  // distributed autograd as part of
  // https://github.com/pytorch/pytorch/issues/33482
  AccumulateGrad::accumulateGrad(
      variable,
      old_grad,
      grad,
      num_expected_refs,
      [this, &variable](at::Tensor&& grad_update) {
        auto device = grad_update.device();
        accumulatedGrads_.insert(variable, std::move(grad_update));
        recordGradEvent(device);
      });
}

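// GraphTask accessors: the task can only be set once at a time,
// retrieveGraphTask() asserts it is present, and resetGraphTask() clears it.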
std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
    retrieveGraphTask() {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(graphTask_);
  return graphTask_;
}

void DistAutogradContext::setGraphTask(
    std::shared_ptr<torch::autograd::GraphTask> graphTask) {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      !graphTask_,
      "Cannot set GraphTask multiple times for the same autograd context");
  graphTask_ = std::move(graphTask);
}

void DistAutogradContext::resetGraphTask() {
  std::lock_guard<std::mutex> guard(lock_);
  graphTask_ = nullptr;
}

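// Tracks an outstanding RPC associated with this context. If the RPC fails,
// the error is propagated to the GraphTask's future so that the backward pass
// surfaces the failure.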
void DistAutogradContext::addOutstandingRpc(
    const c10::intrusive_ptr<rpc::JitFuture>& jitFuture) {
  jitFuture->addCallback([this](rpc::JitFuture& future) {
    if (future.hasError()) {
      // If we have an error, let the local autograd engine know about it.
      std::unique_lock<std::mutex> lock(lock_);
      if (graphTask_) {
        graphTask_->set_exception_without_signal(nullptr);
        lock.unlock();
        if (!graphTask_->future_completed_.exchange(true)) {
          graphTask_->future_result_->setErrorIfNeeded(future.exception_ptr());
        }
      } else {
        LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
                     << future.tryRetrieveErrorMessage();
      }
    }
  });
  std::lock_guard<std::mutex> guard(lock_);
  outStandingRpcs_.push_back(jitFuture);
}

void DistAutogradContext::clearOutstandingRpcs() {
  std::unique_lock<std::mutex> lock(lock_);
  outStandingRpcs_.clear();
}

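// Records a CUDA event on the current stream for 'device' after a gradient
// has been written, so that getGradients() can later synchronize with it.
// No-op for CPU devices.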
void DistAutogradContext::recordGradEvent(c10::Device device) {
  if (device.is_cuda()) {
    auto iter = gradReadyEvents_.find(device);
    if (iter == gradReadyEvents_.end()) {
      c10::Event event(device.type());
      event.record(impl_.getStream(event.device()));
      gradReadyEvents_.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(device),
          std::forward_as_tuple(std::move(event)));
    } else {
      iter->second.record(impl_.getStream(device));
    }
  }
}

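// Moves out the currently outstanding RPCs and returns a future that
// completes once all of them have finished, or fails with the first error
// encountered.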
c10::intrusive_ptr<c10::ivalue::Future> DistAutogradContext::
    clearAndWaitForOutstandingRpcsAsync() {
  std::unique_lock<std::mutex> lock(lock_);
  auto outStandingRpcs = std::move(outStandingRpcs_);
  lock.unlock();

  struct State {
    explicit State(int32_t count)
        : future(
              c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get())),
          remaining(count) {}
    c10::intrusive_ptr<c10::ivalue::Future> future;
    std::atomic<int32_t> remaining;
    std::atomic<bool> alreadySentError{false};
  };
  auto state = std::make_shared<State>(outStandingRpcs.size());
  if (outStandingRpcs.empty()) {
    state->future->markCompleted(c10::IValue());
  } else {
    for (auto& rpc : outStandingRpcs) {
      rpc->addCallback([state](rpc::JitFuture& future) {
        if (future.hasError()) {
          // If there's an error, we want to setError() on the future, unless
          // another error has already been sent - use a CAS to guard.
          //
          // Don't decrement 'remaining' here! (We don't need to, since memory
          // handling is separate.) If we simply don't decrement on errors,
          // reaching 0 means that there were no errors - and hence, we can
          // just markCompleted() without any other checking there.
          bool expectedAlreadySent = false;
          if (state->alreadySentError.compare_exchange_strong(
                  expectedAlreadySent, true)) {
            state->future->setError(future.exception_ptr());
          }
          return;
        }

        if (--state->remaining == 0) {
          state->future->markCompleted(c10::IValue());
        }
      });
    }
  }
  return state->future;
}

std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
    int64_t autograd_message_id) {
  std::lock_guard<std::mutex> guard(lock_);
  auto it = sendAutogradFunctions_.find(autograd_message_id);
  TORCH_CHECK(
      it != sendAutogradFunctions_.end(),
      "Could not find send function for autograd message id: ",
      autograd_message_id);
  return it->second;
}

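// Returns the gradients accumulated in this context so far, keyed by
// variable.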
const c10::Dict<torch::Tensor, torch::Tensor> DistAutogradContext::
    getGradients() const {
  std::lock_guard<std::mutex> guard(lock_);
  // Block the current streams on the gradient-ready events to make sure all
  // gradient computations have finished before the gradients are used.
  for (auto& entry : gradReadyEvents_) {
    auto& event = entry.second;
    event.block(impl_.getStream(event.device()));
  }
  return accumulatedGrads_;
}

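// Runs 'cb' on the gradient accumulated for 'variable'. If the callback
// returns true, the (possibly modified) gradient is written back to the map
// and a gradient-ready event is recorded for its device.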
void DistAutogradContext::runGradCallbackForVariable(
    const torch::autograd::Variable& variable,
    GradCallback&& cb) {
  torch::Tensor grad;
  {
    std::lock_guard<std::mutex> guard(lock_);
    auto it = accumulatedGrads_.find(variable);
    TORCH_INTERNAL_ASSERT(
        it != accumulatedGrads_.end(),
        "The grad for the variable should exist in dist_autograd context.");
    grad = it->value();
  }
  if (cb(grad)) {
    std::lock_guard<std::mutex> guard(lock_);
    auto device = grad.device();
    // Need to update the grad in the map.
    accumulatedGrads_.insert_or_assign(variable, std::move(grad));
    recordGradEvent(device);
  }
}

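// Thread-local pointer to the context associated with the current thread
// (if any); see ThreadLocalDistAutogradContext below.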
namespace {
thread_local ContextPtr tl_context_ptr;
} // namespace

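// RAII guard that installs 'new_context' as the thread-local context for the
// lifetime of the guard and restores the previous context on destruction.
//
// Illustrative usage (a sketch, not copied from an actual call site):
//
//   {
//     ThreadLocalDistAutogradContext guard(std::move(context_ptr));
//     // Code running here sees 'context_ptr' via getContextPtr().
//   } // The previous thread-local context is restored here.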
ThreadLocalDistAutogradContext::ThreadLocalDistAutogradContext(
    ContextPtr&& new_context)
    : prev_context_ptr_(std::move(tl_context_ptr)) {
  tl_context_ptr = std::move(new_context);
}

ThreadLocalDistAutogradContext::~ThreadLocalDistAutogradContext() {
  tl_context_ptr = std::move(prev_context_ptr_);
}

// static
ContextPtr ThreadLocalDistAutogradContext::getContextPtr() {
  return tl_context_ptr;
}

} // namespace autograd
} // namespace distributed
} // namespace torch