#include <queue>

#include <ATen/Parallel.h>
#include <c10/core/Event.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;
using torch::autograd::edge_list;
using torch::autograd::Engine;
using torch::autograd::GraphRoot;
using torch::autograd::GraphTask;
using torch::autograd::GraphTaskGuard;
using torch::autograd::InputBuffer;
using torch::autograd::Node;
using torch::autograd::NodeTask;
using torch::autograd::ReadyQueue;
using torch::autograd::validate_outputs;
using torch::autograd::variable_list;

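// Keys for the debug info map returned by DistEngine::getDebugInfo().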
static constexpr const char* kNumBackwardPasses = "num_current_backward_passes";
static constexpr const char* kNumAutogradContexts = "num_autograd_contexts";

// This hook does 3 things:
// 1. Call pre hooks of the original AccumulateGrad to modify the input grad.
// 2. Accumulate the grad into the RPC autograd context.
// 3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
    : public GraphTask::ExecInfo::Capture::GradCaptureHook {
 public:
  DistAccumulateGradCaptureHook(
      std::shared_ptr<AccumulateGrad> accumulateGrad,
      ContextPtr autogradContext)
      : accumulateGrad_(std::move(accumulateGrad)),
        autogradContext_(std::move(autogradContext)) {}

  at::Tensor operator()(const at::Tensor& grad) override {
    ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
    variable_list inputGrads = {grad};
    // It's intended that pre/post hooks are still called even if the grad is
    // undefined here.
    for (const auto& hook : accumulateGrad_->pre_hooks()) {
      inputGrads = (*hook)(inputGrads);
    }
    // It is possible that the grad is not defined since a separate
    // invocation of the autograd engine on the same node might actually
    // compute this gradient.
    if (inputGrads[0].defined()) {
      // There are 3 internal references to 'inputGrads[0]' at this moment:
      // 1. 'inputGrads[0]' in this function.
      // 2. 'graph_task->captured_vars_' on the callsite in the local engine.
      // 3. 'InputBuffer& inputs' on the callsite as the inputs of the
      //    function node.
      autogradContext_->accumulateGrad(
          accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
    }
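    // AccumulateGrad produces no outputs, so its post hooks receive an empty
    // output list.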
    const variable_list kEmptyOutput;
    for (const auto& hook : accumulateGrad_->post_hooks()) {
      (*hook)(kEmptyOutput, inputGrads);
    }
    return inputGrads[0];
  }

 private:
  std::shared_ptr<AccumulateGrad> accumulateGrad_;
  ContextPtr autogradContext_;
};

void DistEngine::globalCpuThread(
    const std::shared_ptr<ReadyQueue>& ready_queue) {
  c10::setThreadName("pt_dist_engine");
  while (true) {
    NodeTask task = ready_queue->pop();
    if (task.isShutdownTask_) {
      // Need to shut down this thread.
      C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
      break;
    }

    auto graphTask = task.base_.lock();
    if (graphTask == nullptr) {
      // GraphTask has expired, ignore and continue processing.
      continue;
    }

    // Launch the execution on a JIT thread.
    at::launch([this,
                graphTask,
                graphRoot = task.fn_,
                variables =
                    InputBuffer::variables(std::move(task.inputs_))]() mutable {
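      // Rebuild the InputBuffer from the captured variables on the launched
      // thread; the NodeTask's buffer was moved out above.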
      InputBuffer inputs(variables.size());
      for (const auto i : c10::irange(variables.size())) {
        inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
      }
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
          /*incrementOutstandingTasks*/ false);
    });
  }
}

DistEngine::DistEngine()
    : initializedContextIds_(),
      engine_(Engine::get_default_engine()),
      global_cpu_ready_queue_(std::make_shared<ReadyQueue>()),
      global_cpu_thread_(
          &DistEngine::globalCpuThread,
          this,
          global_cpu_ready_queue_) {
  // Note [GPU to CPU continuations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Initialize a single CPU thread to execute continuations from GPU
  // tasks. The multithreaded structure for the distributed engine works
  // well only for CPU tasks. If we have an order of tasks like
  // CPU->GPU->CPU, distributed autograd has no thread to execute the last
  // CPU task on. To fix this, we introduce a global CPU thread to handle
  // such situations and it will be responsible for executing these CPU
  // tasks. The CPU thread has its own ready_queue which is used as the
  // cpu_ready_queue for all GraphTasks for DistEngine. This ensures all GPU
  // to CPU continuations are enqueued on this thread. The global CPU thread
  // simply dequeues tasks from the global queue and calls
  // "execute_graph_task_until_ready_queue_empty" on a JIT thread to execute
  // the appropriate task.
  global_cpu_thread_.detach();
}

DistEngine::~DistEngine() {
  // Ensure we shut down the CPU thread.
  TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
  global_cpu_ready_queue_->pushShutdownTask();
  global_cpu_thread_.join();
}

DistEngine& DistEngine::getInstance() {
  // Leaky singleton to avoid module destructor race.
  static DistEngine* engine = new DistEngine();
  return *engine;
}

void DistEngine::validateRootsAndRetrieveEdges(
    const variable_list& roots,
    edge_list& rootEdges,
    variable_list& grads) {
  TORCH_CHECK(!roots.empty(), "No tensors provided for gradient computation.");
  TORCH_INTERNAL_ASSERT(rootEdges.empty());
  TORCH_INTERNAL_ASSERT(grads.empty());

  // Verify roots are all scalar and require gradients.
  for (const auto& root : roots) {
    TORCH_CHECK(root.requires_grad(), "requires_grad not set on root");
    TORCH_CHECK(
        root.numel() == 1,
        root.name(),
        " is not a scalar, all roots need to be scalar");
    TORCH_CHECK(
        root.grad_fn(),
        root.name(),
        " does not have a valid gradient function.");

    // Compute the root edges and generate the appropriate gradients.
    rootEdges.push_back(torch::autograd::impl::gradient_edge(root));
    grads.push_back(at::ones_like(root, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
  }

  // Validate rootEdges and grads.
  validate_outputs(
      rootEdges, grads, [](const std::string& msg) { return msg; });
}

void DistEngine::computeDependencies(
    const ContextPtr& autogradContext,
    const edge_list& rootEdges,
    const variable_list& grads,
    const std::shared_ptr<Node>& graphRoot,
    edge_list& outputEdges,
    bool retainGraph) {
  TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
  for (const auto i : c10::irange(rootEdges.size())) {
    temp_roots[i] = rootEdges[i].function.get();
  }

  // Build the graph task and graph root.
  // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
  // as we use execute_graph_task_until_ready_queue_empty, which will build
  // a separate ReadyQueue for each call.
  auto graphTask = std::make_shared<GraphTask>(
      /* keep_graph */ retainGraph,
      /* create_graph */ false,
      /* depth */ 0,
      /* cpu_ready_queue */ global_cpu_ready_queue_,
      /* graph_roots */ temp_roots,
      /* exit_on_error */ true);

  // Run BFS to traverse the graph locally. The roots of the graph are
  // GraphRoot and all send functions for this autograd context.
  std::unordered_set<Node*> seen;
  std::queue<Node*> queue;
  queue.push(static_cast<Node*>(graphRoot.get()));

  auto sendFunctions = autogradContext->sendFunctions();

  // Add all the send functions to the queue as roots.
  for (const auto& mapEntry : sendFunctions) {
    // Increment 'outstanding_tasks_' for GraphTask for each send_function
    // since we want the local autograd engine to wait for all of them.
    graphTask->outstanding_tasks_++;
    queue.push(mapEntry.second.get());
  }

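  // Track whether any function in the graph is tied to an accelerator stream,
  // so we know whether to stash the current streams for post-processing below.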
  bool will_use_accelerator = false;

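  // Leaf edges that point to RecvRpcBackward; unlike AccumulateGrad leaves,
  // these still need to be marked for execution after init_to_execute (see
  // below).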
  edge_list recvBackwardEdges;
  // Traverse the graph.
  auto& dependencies = graphTask->dependencies_;
  while (!queue.empty()) {
    auto fn = queue.front();
    queue.pop();

    if (!will_use_accelerator) {
      will_use_accelerator = fn->stream().has_value();
    }

    for (const auto& edge : fn->next_edges()) {
      if (auto nextFn = edge.function.get()) {
        dependencies[nextFn] += 1;
        const bool wasInserted = seen.insert(nextFn).second;
        if (wasInserted) {
          // Seeing this function for the first time.
          queue.push(nextFn);

          if (nextFn->next_edges().empty()) {
            TORCH_INTERNAL_ASSERT(
                dynamic_cast<AccumulateGrad*>(nextFn) ||
                dynamic_cast<RecvRpcBackward*>(nextFn));
            // We have found a leaf node, which should be either AccumulateGrad
            // or RecvRpcBackward. Record the function to ensure we don't
            // execute it and instead accumulate the grads on the autograd
            // context. These functions would be passed in as the 'outputs'
            // parameter of the vanilla autograd engine.

            // We don't accumulate any grads in the context for RecvRpcBackward.
            // RecvRpcBackward is added as an output edge to indicate it is a
            // leaf node and this helps in properly computing dependencies for
            // the local autograd graph. Putting RecvRpcBackward in
            // 'outputEdges' means that this function needs to be executed
            // (in line with our assumption for FAST mode that all send/recv
            // functions are valid in the backward pass), and as a result all
            // of its ancestors need to be executed as well.
            if (dynamic_cast<RecvRpcBackward*>(nextFn)) {
              recvBackwardEdges.emplace_back(edge);
            }
            outputEdges.emplace_back(edge);
          }
        }
      }
    }
  }

  if (will_use_accelerator) {
    // Collects current streams for CUDA/ROCM devices where this process has a
    // context, so graphTask::exec_post_processing can sync them with
    // leaf_streams.
    graphTask->stash_current_streams();
  }

  // Now let's compute which functions need to be executed. The algorithm is as
  // follows:
  // 1. Create a dummy GraphRoot which points to all 'send' functions for this
  //    context and the original graphRoot. Run 'init_to_execute' with the
  //    outputEdges and the dummy GraphRoot. This ensures we mark
  //    appropriate functions as needed if they are reachable only from a
  //    specific 'send' function locally and not necessarily from the provided
  //    roots.
  // 2. For all edges in 'outputEdges' which point to 'RecvRpcBackward', mark
  //    those functions as needed for execution. The reason for this is that
  //    'init_to_execute' will mark these as not needed. But 'RecvRpcBackward'
  //    is unique in the sense that we use it as a leaf node in the graph to
  //    compute needed execution accurately, but unlike AccumulateGrad, we do
  //    need to execute this function.
  if (!outputEdges.empty()) {
    // Compute 'needed execution' starting from all 'send' functions and the
    // original graphRoot.
    edge_list edges;
    // Create some dummy edges (input_nr not important for init_to_execute).
    for (const auto& mapEntry : sendFunctions) {
      edges.emplace_back(mapEntry.second, 0);
    }

    // Add the original graphRoot as an edge.
    edges.emplace_back(graphRoot, 0);

    // Create a dummy GraphRoot and run init_to_execute with it.
    GraphRoot dummyRoot(edges, {});
    graphTask->init_to_execute(
        dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
    for (auto& mapEntry : graphTask->exec_info_) {
      auto& execInfo = mapEntry.second;
      if (!execInfo.captures_) {
        continue;
      }
      auto fn = mapEntry.first;
      // There may be nodes other than 'AccumulateGrad', e.g. RecvRpcBackward,
      // to be captured.
      if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
        for (auto& capture : *execInfo.captures_) {
          // Capture hooks are technically deprecated, but as an exception this
          // is the single supported use of capture hooks. See
          // NOTE [Deprecated capture hooks] for more context.
          capture.DO_NOT_USE_DEPRECATED_register_capture_hook(
              std::make_unique<DistAccumulateGradCaptureHook>(
                  std::dynamic_pointer_cast<AccumulateGrad>(
                      accumulateGradFn->shared_from_this()),
                  autogradContext));
        }
      }
    }

    // Mark all 'RecvRpcBackward' as needing execution.
    for (const auto& recvBackwardEdge : recvBackwardEdges) {
      graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
    }
  }

  // Set graph task owner in a single thread since concurrent access to
  // 'owner_' field is not permitted.
  graphTask->owner_ = torch::autograd::CPU_DEVICE;

  // Let autograd context take ownership of the GraphTask.
  autogradContext->setGraphTask(std::move(graphTask));
}

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {
  engine_.initialize_device_threads_pool();
  // Create a ready queue per call to traverse the graph_task from
  // root_to_execute. This allows concurrent execution of the same GraphTask
  // from different threads.
  std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
  auto graph_task = node_task.base_.lock();
  if (graph_task == nullptr) {
    LOG(ERROR) << "GraphTask has expired for NodeTask: "
               << node_task.fn_->name() << ", skipping execution.";
    return;
  }

  cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);

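  // Mark this thread as the engine's CPU worker thread while it drains the
  // per-call ready queue.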
  torch::autograd::set_device(torch::autograd::CPU_DEVICE);
  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_).
      NodeTask task = cpu_ready_queue->pop();
      if (!(local_graph_task = task.base_.lock())) {
        continue;
      }
      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
        try {
          GraphTaskGuard guard(local_graph_task);
          engine_.evaluate_function(
              local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
        } catch (std::exception& e) {
          engine_.thread_on_exception(local_graph_task, task.fn_, e);
          // Break the loop on error so that we immediately stop executing this
          // GraphTask, mark it as completed if necessary, and return the
          // future with the proper error message.
          break;
        }
      }
    }
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_;
  }
  // Check if we've completed execution.
  if (graph_task->completed()) {
    // We don't need to explicitly notify the owner thread, since
    // 'mark_as_completed_and_run_post_processing' would mark the Future as
    // completed and this would notify the owner thread that the task has been
    // completed.
    graph_task->mark_as_completed_and_run_post_processing();
  }
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
    runEngineAndAccumulateGradients(
        const ContextPtr& autogradContext,
        const std::shared_ptr<Node>& graphRoot,
        const edge_list& outputEdges,
        bool incrementOutstandingTasks) {
  // Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
  // lingering if we're running backward multiple times and some of the
  // passes ran into errors.
  autogradContext->clearOutstandingRpcs();
  auto graphTask = autogradContext->retrieveGraphTask();
  at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
    execute_graph_task_until_ready_queue_empty(
        /*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
        /*incrementOutstandingTasks*/ incrementOutstandingTasks);
  });
  // Use a reference here to avoid refcount bump on futureGrads.
  auto& futureGrads = graphTask->future_result_;

  // Build a future that waits for the callbacks to execute (since callbacks
  // execute after the original future is completed). This ensures we return a
  // future that waits for all gradient accumulation to finish.
  auto accumulateGradFuture =
      c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture](
                               c10::ivalue::Future& futureGrads) {
    if (futureGrads.hasError()) {
      // Don't accumulate gradients if we receive an error.
      // We must add the node information here since DistEngine::execute
      // waits on accumulateGradFuture and will throw an exception once we
      // set the error below.
      std::string errorMsg = c10::str(
          "Error on Node ",
          DistAutogradContainer::getInstance().getWorkerId(),
          ": ",
          futureGrads.tryRetrieveErrorMessage());
      accumulateGradFuture->setError(std::make_exception_ptr(
          c10::ivalue::Future::FutureError(std::move(errorMsg))));
      return;
    }

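    // On success, the gradients have already been accumulated into the
    // autograd context by DistAccumulateGradCaptureHook; here we only validate
    // the number of captured grads and mark the future as completed.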
    try {
      const variable_list& grads = futureGrads.constValue().toTensorVector();
      TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
      accumulateGradFuture->markCompleted(c10::IValue());
    } catch (std::exception& e) {
      accumulateGradFuture->setErrorIfNeeded(std::current_exception());
    }
  });

  return accumulateGradFuture;
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
    const ContextPtr& autogradContext,
    const std::shared_ptr<SendRpcBackward>& sendFunction,
    bool retainGraph) {
  // Typically the local autograd engine ensures stream synchronization between
  // nodes in the graph. However, for distributed autograd the sendFunction
  // inputs might have been retrieved over the wire on a separate stream and
  // the sendFunction itself runs on a different stream. As a result, we need
  // to manually synchronize those two streams here.
  const auto& send_backward_stream = sendFunction->stream();
  if (send_backward_stream) {
    for (const auto& grad : sendFunction->getGrads()) {
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      const auto default_stream = guard.getStream(grad.device());
      if (send_backward_stream != default_stream) {
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(default_stream);
        send_backward_stream->wait(event);
      }
    }
  }

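  // If this is the first send function seen for this autograd context, compute
  // dependencies, build the GraphTask and kick off the local engine; otherwise
  // the GraphTask already exists and we simply enqueue this send function.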
  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
  if (initializedContextIds_.find(autogradContext->contextId()) ==
      initializedContextIds_.end()) {
    edge_list outputEdges;
    // Pass in a dummy graphRoot since all send functions are the roots.
    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
    computeDependencies(
        autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized and unlock.
    initializedContextIds_.insert(autogradContext->contextId());
    lock.unlock();

    // Enqueue the current send function.
    auto graphTask = autogradContext->retrieveGraphTask();
    // Run the autograd engine.
    auto accumulateGradFuture = runEngineAndAccumulateGradients(
        autogradContext,
        sendFunction,
        outputEdges,
        /*incrementOutstandingTasks=*/false);

    // Build the 'uber' future that waits for everything.
    auto callbackFuture =
        c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

    accumulateGradFuture->addCallback(
        [autogradContext,
         callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
          try {
            if (accumulateGradFuture.hasError()) {
              // Perform cleanup at the end of the backward pass (before we mark
              // the future as completed).
              DistEngine::getInstance().cleanupBackwardPass(autogradContext);

              // Skip any further processing on errors.
              callbackFuture->setError(accumulateGradFuture.exception_ptr());
              return;
            }

            // Wait for all RPCs after the autograd engine is done.
            auto rpcFuture =
                autogradContext->clearAndWaitForOutstandingRpcsAsync();
            rpcFuture->addCallback([callbackFuture, autogradContext](
                                       c10::ivalue::Future& rpcFuture) {
              try {
                // Perform cleanup at the end of the backward pass (before
                // we mark the future as completed).
                DistEngine::getInstance().cleanupBackwardPass(autogradContext);
              } catch (std::exception& e) {
                callbackFuture->setErrorIfNeeded(std::current_exception());
                return;
              }

              // Finally mark the 'uber' future as completed.
              if (!rpcFuture.hasError()) {
                callbackFuture->markCompleted(c10::IValue());
              } else {
                callbackFuture->setError(rpcFuture.exception_ptr());
              }
            });
          } catch (std::exception& e) {
            callbackFuture->setErrorIfNeeded(std::current_exception());
          }
        });

    // Return the future which waits for all async processing to be done.
    return callbackFuture;
  } else {
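    // Dependencies for this context were already computed; just run this send
    // function on the existing GraphTask and return an already completed
    // future.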
    lock.unlock();
    auto graphTask = autogradContext->retrieveGraphTask();
    at::launch([this, graphTask, sendFunction]() {
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
          /*incrementOutstandingTasks*/ false);
    });
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted(c10::IValue());
    return fut;
  }
}

void DistEngine::execute(
    int64_t contextId,
    const variable_list& roots,
    bool retainGraph) {
  // Retrieve the context for the given context_id. This will throw if the
  // context_id is invalid.
  auto autogradContext =
      DistAutogradContainer::getInstance().retrieveContext(contextId);

  // Perform initial pre-processing.
  edge_list rootEdges;
  variable_list grads;
  validateRootsAndRetrieveEdges(roots, rootEdges, grads);

  std::shared_ptr<Node> graphRoot =
      std::make_shared<GraphRoot>(rootEdges, grads);
  edge_list outputEdges;
  // Compute dependencies locally, starting from all roots and all 'send'
  // functions.
  {
    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
    // Context should not have been initialized already.
    TORCH_INTERNAL_ASSERT(
        initializedContextIds_.find(autogradContext->contextId()) ==
        initializedContextIds_.end());

    computeDependencies(
        autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized.
    initializedContextIds_.insert(autogradContext->contextId());
  }

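  // RAII guard that cleans up the backward pass state for this context once
  // we're done, including on error paths.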
  BackwardPassCleanupGuard guard(autogradContext);

  // This needs to be blocking and as a result we wait for the future to
  // complete.
  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
      ->waitAndThrow();

  // Wait for all of the outstanding rpcs to complete.
  autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}

void DistEngine::cleanupBackwardPass(const ContextPtr& autogradContext) {
  // Validate only the GraphTask is holding a reference to the Future
  // which holds gradients for the backward pass. This ensures that
  // after 'resetGraphTask' is called below, there are no remaining
  // references left to the gradients for the backward pass.
  //
  // This ensures our 'use_count' checks in
  // AccumulateGrad::accumulateGrad are correct and we're
  // not leaking any references to the gradients anywhere else.
  const auto& futureGrads =
      autogradContext->retrieveGraphTask()->future_result_;
  TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1);

  // Reset the graph task once we're done with all processing.
  autogradContext->resetGraphTask();

  // Clear any outstanding rpcs.
  autogradContext->clearOutstandingRpcs();

  // Clear the context id once we're done with the autograd engine
  // processing.
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  initializedContextIds_.erase(autogradContext->contextId());
}

size_t DistEngine::numBackwardPasses() const {
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  return initializedContextIds_.size();
}

std::unordered_map<std::string, int> DistEngine::getDebugInfo() const {
  std::unordered_map<std::string, int> debugInfo;
  debugInfo[kNumBackwardPasses] = numBackwardPasses();
  debugInfo[kNumAutogradContexts] =
      DistAutogradContainer::getInstance().numAutogradContexts();
  return debugInfo;
}

} // namespace autograd
} // namespace distributed
} // namespace torch