xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/engine/dist_engine.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <queue>

#include <ATen/Parallel.h>
#include <c10/core/Event.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;
using torch::autograd::edge_list;
using torch::autograd::Engine;
using torch::autograd::GraphRoot;
using torch::autograd::GraphTask;
using torch::autograd::GraphTaskGuard;
using torch::autograd::InputBuffer;
using torch::autograd::Node;
using torch::autograd::NodeTask;
using torch::autograd::ReadyQueue;
using torch::autograd::validate_outputs;
using torch::autograd::variable_list;

static constexpr const char* kNumBackwardPasses = "num_current_backward_passes";
static constexpr const char* kNumAutogradContexts = "num_autograd_contexts";

// This hook does 3 things:
//   1. Call pre hooks of the original AccumulateGrad to modify the input grad.
//   2. Accumulate the grad into the RPC autograd context.
//   3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
    : public GraphTask::ExecInfo::Capture::GradCaptureHook {
 public:
  DistAccumulateGradCaptureHook(
      std::shared_ptr<AccumulateGrad> accumulateGrad,
      ContextPtr autogradContext)
      : accumulateGrad_(std::move(accumulateGrad)),
        autogradContext_(std::move(autogradContext)) {}

  at::Tensor operator()(const at::Tensor& grad) override {
    ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
    variable_list inputGrads = {grad};
    // It's intended that pre/post hooks are still called even if the grad is
    // undefined here.
    for (const auto& hook : accumulateGrad_->pre_hooks()) {
      inputGrads = (*hook)(inputGrads);
    }
    // It is possible that the grad is not defined since a separate
    // invocation of the autograd engine on the same node might actually
    // compute this gradient.
    if (inputGrads[0].defined()) {
      // There are 3 internal references to 'inputGrads[0]' at this moment:
      //   1. 'inputGrads[0]' in this function.
      //   2. 'graph_task->captured_vars_' on the callsite in the local engine.
      //   3. 'InputBuffer& inputs' on the callsite as the inputs of the
      //   function node.
      autogradContext_->accumulateGrad(
          accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
    }
    const variable_list kEmptyOutput;
    for (const auto& hook : accumulateGrad_->post_hooks()) {
      (*hook)(kEmptyOutput, inputGrads);
    }
    return inputGrads[0];
  }

 private:
  std::shared_ptr<AccumulateGrad> accumulateGrad_;
  ContextPtr autogradContext_;
};
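
// Note (descriptive comment added for clarity): DistEngine::computeDependencies()
// below attaches an instance of this hook to every captured AccumulateGrad leaf
// (via DO_NOT_USE_DEPRECATED_register_capture_hook), so gradients produced by
// the local engine are routed into the distributed autograd context instead of
// being accumulated into the variable's .grad field.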

void DistEngine::globalCpuThread(
    const std::shared_ptr<ReadyQueue>& ready_queue) {
  c10::setThreadName("pt_dist_engine");
  while (true) {
    NodeTask task = ready_queue->pop();
    if (task.isShutdownTask_) {
      // Need to shutdown this thread.
      C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
      break;
    }

    auto graphTask = task.base_.lock();
    if (graphTask == nullptr) {
      // GraphTask has expired, ignore and continue processing.
      continue;
    }

    // Launch the execution on a JIT thread.
    at::launch([this,
                graphTask,
                graphRoot = task.fn_,
                variables =
                    InputBuffer::variables(std::move(task.inputs_))]() mutable {
      InputBuffer inputs(variables.size());
      for (const auto i : c10::irange(variables.size())) {
        inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
      }
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
          /*incrementOutstandingTasks*/ false);
    });
  }
}

DistEngine::DistEngine()
    : initializedContextIds_(),
      engine_(Engine::get_default_engine()),
      global_cpu_ready_queue_(std::make_shared<ReadyQueue>()),
      global_cpu_thread_(
          &DistEngine::globalCpuThread,
          this,
          global_cpu_ready_queue_) {
  // Note [GPU to CPU continuations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Initialize a single CPU thread to execute continuations from GPU
  // tasks. The multithreaded structure for the distributed engine works
  // well only for CPU tasks. If we have an order of tasks like
  // CPU->GPU->CPU, distributed autograd has no thread to execute the last
  // CPU task on. To fix this, we introduce a global CPU thread to handle
  // such situations and it will be responsible for executing these CPU
  // tasks. The CPU thread has its own ready_queue which is used as the
  // cpu_ready_queue for all GraphTasks for DistEngine. This ensures all GPU
  // to CPU continuations are enqueued on this thread. The global CPU thread
  // simply dequeues tasks from the global queue and calls
  // "execute_graph_task_until_ready_queue_empty" on a JIT thread to execute the
  // appropriate task.
  global_cpu_thread_.detach();
}
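
// Illustration (added note, not from the upstream sources): the CPU->GPU->CPU
// task ordering described in Note [GPU to CPU continuations] arises, for
// example, when the forward pass moves data onto a CUDA device and back, e.g.
//   y = cpu_op(x).cuda()
//   z = cuda_op(y).cpu()
// During backward, the CPU continuation that follows the GPU task is enqueued
// onto global_cpu_ready_queue_ and picked up by globalCpuThread() above.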

DistEngine::~DistEngine() {
  // Ensure we shutdown the CPU thread.
  TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
  global_cpu_ready_queue_->pushShutdownTask();
  global_cpu_thread_.join();
}

DistEngine& DistEngine::getInstance() {
  // Leaky singleton to avoid module destructor race.
  static DistEngine* engine = new DistEngine();
  return *engine;
}

void DistEngine::validateRootsAndRetrieveEdges(
    const variable_list& roots,
    edge_list& rootEdges,
    variable_list& grads) {
  TORCH_CHECK(!roots.empty(), "No tensors provided for gradient computation.");
  TORCH_INTERNAL_ASSERT(rootEdges.empty());
  TORCH_INTERNAL_ASSERT(grads.empty());

  // Verify roots are all scalar and require gradients.
  for (const auto& root : roots) {
    TORCH_CHECK(root.requires_grad(), "requires_grad not set on root");
    TORCH_CHECK(
        root.numel() == 1,
        root.name(),
        " is not a scalar, all roots need to be scalar");
    TORCH_CHECK(
        root.grad_fn(),
        root.name(),
        " does not have a valid gradient function.");

    // Compute the root edges and generate the appropriate gradients.
    rootEdges.push_back(torch::autograd::impl::gradient_edge(root));
    grads.push_back(at::ones_like(root, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
  }

  // Validate rootEdges and grads.
  validate_outputs(
      rootEdges, grads, [](const std::string& msg) { return msg; });
}
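
// Illustrative note (added, not from the upstream sources): a typical root
// passed into DistEngine::execute() is a scalar loss computed inside a
// distributed autograd context, so the gradient seeded for each root above is
// simply at::ones_like(root).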

void DistEngine::computeDependencies(
    const ContextPtr& autogradContext,
    const edge_list& rootEdges,
    const variable_list& grads,
    const std::shared_ptr<Node>& graphRoot,
    edge_list& outputEdges,
    bool retainGraph) {
  TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
  for (const auto i : c10::irange(rootEdges.size())) {
    temp_roots[i] = rootEdges[i].function.get();
  }

  // Build the graph task and graph root.
  // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
  // as we use execute_graph_task_until_ready_queue_empty, which will build
  // a separate ReadyQueue for each call.
  auto graphTask = std::make_shared<GraphTask>(
      /* keep_graph */ retainGraph,
      /* create_graph */ false,
      /* depth */ 0,
      /* cpu_ready_queue */ global_cpu_ready_queue_,
      /* graph_roots */ temp_roots,
      /* exit_on_error */ true);

  // Run BFS to traverse the graph locally. The roots of the graph are
  // GraphRoot and all send functions for this autograd context.
  std::unordered_set<Node*> seen;
  std::queue<Node*> queue;
  queue.push(static_cast<Node*>(graphRoot.get()));

  auto sendFunctions = autogradContext->sendFunctions();

  // Add all the send functions to the queue as roots.
  for (const auto& mapEntry : sendFunctions) {
    // Increment 'outstanding_tasks_' for GraphTask for each send_function
    // since we want the local autograd engine to wait for all of them.
    graphTask->outstanding_tasks_++;
    queue.push(mapEntry.second.get());
  }

  bool will_use_accelerator = false;

  edge_list recvBackwardEdges;
  // Traverse the graph.
  auto& dependencies = graphTask->dependencies_;
  while (!queue.empty()) {
    auto fn = queue.front();
    queue.pop();

    if (!will_use_accelerator) {
      will_use_accelerator = fn->stream().has_value();
    }

    for (const auto& edge : fn->next_edges()) {
      if (auto nextFn = edge.function.get()) {
        dependencies[nextFn] += 1;
        const bool wasInserted = seen.insert(nextFn).second;
        if (wasInserted) {
          // Seeing this function for the first time.
          queue.push(nextFn);

          if (nextFn->next_edges().empty()) {
            TORCH_INTERNAL_ASSERT(
                dynamic_cast<AccumulateGrad*>(nextFn) ||
                dynamic_cast<RecvRpcBackward*>(nextFn));
            // We have found a leaf node which should be either AccumulateGrad
            // or RecvRpcBackward. Record the function
            // to ensure we don't execute it and instead accumulate the grads on
            // the autograd context. These functions would be passed in as the
            // 'outputs' parameter of the vanilla autograd engine.

            // We don't accumulate any grads in the context for RecvRpcBackward.
            // RecvRpcBackward is added as an output edge to indicate it is a
            // leaf node and this helps in properly computing dependencies for
            // the local autograd graph. Putting RecvRpcBackward in
            // 'outputEdges' means that this function needs to be executed
            // (in line with our assumption for FAST mode that all send/recv
            // functions are valid in the backward pass), and as a result all of
            // its ancestors need to be executed as well.
            if (dynamic_cast<RecvRpcBackward*>(nextFn)) {
              recvBackwardEdges.emplace_back(edge);
            }
            outputEdges.emplace_back(edge);
          }
        }
      }
    }
  }

  if (will_use_accelerator) {
    // Collects current streams for CUDA/ROCM devices where this process has a
    // context, so graphTask::exec_post_processing can sync them with
    // leaf_streams.
    graphTask->stash_current_streams();
  }

  // Now let's compute which functions need to be executed. The algorithm is as
  // follows:
  // 1. Create a dummy GraphRoot which points to all 'send' functions for this
  //    context and the original graphRoot. Run 'init_to_execute' with the
  //    outputEdges and the dummy GraphRoot. This ensures we mark
  //    appropriate functions as needed if they are reachable only from a
  //    specific 'send' function locally and not necessarily from the provided
  //    roots.
  // 2. For all edges in 'outputEdges' which point to 'RecvRpcBackward', mark
  //    those functions as needed for execution. The reason for this is that
  //    'init_to_execute' will mark these as not needed. But 'RecvRpcBackward'
  //    is unique in the sense that we use it as a leaf node in the graph to
  //    compute the needed execution accurately, but unlike AccumulateGrad, we
  //    do need to execute this function.
  if (!outputEdges.empty()) {
    // Compute 'needed execution' starting from all 'send' functions and the
    // original graphRoot.
    edge_list edges;
    // Create some dummy edges (input_nr not important for init_to_execute).
    for (const auto& mapEntry : sendFunctions) {
      edges.emplace_back(mapEntry.second, 0);
    }

    // Add the original graphRoot as an edge.
    edges.emplace_back(graphRoot, 0);

    // Create a dummy GraphRoot and run init_to_execute with it.
    GraphRoot dummyRoot(edges, {});
    graphTask->init_to_execute(
        dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
    for (auto& mapEntry : graphTask->exec_info_) {
      auto& execInfo = mapEntry.second;
      if (!execInfo.captures_) {
        continue;
      }
      auto fn = mapEntry.first;
      // There may be nodes other than 'AccumulateGrad', e.g. RecvRPCBackward,
      // to be captured.
      if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
        for (auto& capture : *execInfo.captures_) {
          // Capture hooks are technically deprecated, but as an exception below
          // is the single and only instance of capture hooks usage that we
          // support. See NOTE [Deprecated capture hooks] for more context.
          capture.DO_NOT_USE_DEPRECATED_register_capture_hook(
              std::make_unique<DistAccumulateGradCaptureHook>(
                  std::dynamic_pointer_cast<AccumulateGrad>(
                      accumulateGradFn->shared_from_this()),
                  autogradContext));
        }
      }
    }

    // Mark all 'RecvRPCBackward' as needing execution.
    for (const auto& recvBackwardEdge : recvBackwardEdges) {
      graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
    }
  }

  // Set graph task owner in a single thread since concurrent access to
  // 'owner_' field is not permitted.
  graphTask->owner_ = torch::autograd::CPU_DEVICE;

  // Let autograd context take ownership of the GraphTask.
  autogradContext->setGraphTask(std::move(graphTask));
}
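
// Added summary (descriptive comment): on return, computeDependencies() has
// recorded every local leaf (AccumulateGrad or RecvRpcBackward) in
// 'outputEdges'; when output edges exist, it has installed
// DistAccumulateGradCaptureHook on the captured AccumulateGrad nodes and marked
// RecvRpcBackward nodes as needing execution; and the populated GraphTask has
// been handed to the autograd context via setGraphTask().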

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {
  engine_.initialize_device_threads_pool();
  // Create a ready queue per call to traverse the graph_task from
  // root_to_execute. This allows concurrent execution of the same GraphTask
  // from different threads.
  std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
  auto graph_task = node_task.base_.lock();
  if (graph_task == nullptr) {
    LOG(ERROR) << "GraphTask has expired for NodeTask: "
               << node_task.fn_->name() << ", skipping execution.";
    return;
  }

  cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);

  torch::autograd::set_device(torch::autograd::CPU_DEVICE);
  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (releasing any references to grad tensors
      // held as part of inputs_).
      NodeTask task = cpu_ready_queue->pop();
      if (!(local_graph_task = task.base_.lock())) {
        continue;
      }
      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
        try {
          GraphTaskGuard guard(local_graph_task);
          engine_.evaluate_function(
              local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
        } catch (std::exception& e) {
          engine_.thread_on_exception(local_graph_task, task.fn_, e);
          // Break the loop on error so that we immediately stop the execution
          // of this GraphTask, mark it completed if necessary and return the
          // future with the proper error message.
          break;
        }
      }
    }
    // Decrement the outstanding task count.
    --local_graph_task->outstanding_tasks_;
  }
  // Check if we've completed execution.
  if (graph_task->completed()) {
    // We don't need to explicitly notify the owner thread, since
    // 'mark_as_completed_and_run_post_processing' would mark the Future as
    // completed and this would notify the owner thread that the task has been
    // completed.
    graph_task->mark_as_completed_and_run_post_processing();
  }
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
    runEngineAndAccumulateGradients(
        const ContextPtr& autogradContext,
        const std::shared_ptr<Node>& graphRoot,
        const edge_list& outputEdges,
        bool incrementOutstandingTasks) {
  // Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
  // lingering if we're running backward multiple times and some of the
  // passes ran into errors.
  autogradContext->clearOutstandingRpcs();
  auto graphTask = autogradContext->retrieveGraphTask();
  at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
    execute_graph_task_until_ready_queue_empty(
        /*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
        /*incrementOutstandingTasks*/ incrementOutstandingTasks);
  });
  // Use a reference here to avoid refcount bump on futureGrads.
  auto& futureGrads = graphTask->future_result_;

  // Build a future that waits for the callbacks to execute (since callbacks
  // execute after the original future is completed). This ensures we return a
  // future that waits for all gradient accumulation to finish.
  auto accumulateGradFuture =
      c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture](
                               c10::ivalue::Future& futureGrads) {
    if (futureGrads.hasError()) {
      // Don't accumulate gradients if we receive an error.
      // We must add the node information here since DistEngine::execute
      // waits on accumulateGradFuture and will throw an exception once we
      // set the error below.
      std::string errorMsg = c10::str(
          "Error on Node ",
          DistAutogradContainer::getInstance().getWorkerId(),
          ": ",
          futureGrads.tryRetrieveErrorMessage());
      accumulateGradFuture->setError(std::make_exception_ptr(
          c10::ivalue::Future::FutureError(std::move(errorMsg))));
      return;
    }

    try {
      const variable_list& grads = futureGrads.constValue().toTensorVector();
      TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
      accumulateGradFuture->markCompleted(c10::IValue());
    } catch (std::exception& e) {
      accumulateGradFuture->setErrorIfNeeded(std::current_exception());
    }
  });

  return accumulateGradFuture;
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
    const ContextPtr& autogradContext,
    const std::shared_ptr<SendRpcBackward>& sendFunction,
    bool retainGraph) {
  // Typically the local autograd engine ensures stream synchronizations between
  // nodes in the graph. However, for distributed autograd the sendFunction
  // inputs might have been retrieved over the wire on a separate stream and the
  // sendFunction itself runs on a different stream. As a result, we need to
  // manually synchronize those two streams here.
  const auto& send_backward_stream = sendFunction->stream();
  if (send_backward_stream) {
    for (const auto& grad : sendFunction->getGrads()) {
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      const auto default_stream = guard.getStream(grad.device());
      if (send_backward_stream != default_stream) {
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(default_stream);
        send_backward_stream->wait(event);
      }
    }
  }

  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
  if (initializedContextIds_.find(autogradContext->contextId()) ==
      initializedContextIds_.end()) {
    edge_list outputEdges;
    // Pass in a dummy graphRoot since all send functions are the roots.
    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
    computeDependencies(
        autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized and unlock.
    initializedContextIds_.insert(autogradContext->contextId());
    lock.unlock();

    // Enqueue the current send function.
    auto graphTask = autogradContext->retrieveGraphTask();
    // Run the autograd engine.
    auto accumulateGradFuture = runEngineAndAccumulateGradients(
        autogradContext,
        sendFunction,
        outputEdges,
        /*incrementOutstandingTasks=*/false);

    // Build the 'uber' future that waits for everything.
    auto callbackFuture =
        c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

    accumulateGradFuture->addCallback(
        [autogradContext,
         callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
          try {
            if (accumulateGradFuture.hasError()) {
              // Perform cleanup at the end of the backward pass (before we mark
              // the future as completed).
              DistEngine::getInstance().cleanupBackwardPass(autogradContext);

              // Skip any further processing on errors.
              callbackFuture->setError(accumulateGradFuture.exception_ptr());
              return;
            }

            // Wait for all RPCs after the autograd engine is done.
            auto rpcFuture =
                autogradContext->clearAndWaitForOutstandingRpcsAsync();
            rpcFuture->addCallback([callbackFuture, autogradContext](
                                       c10::ivalue::Future& rpcFuture) {
              try {
                // Perform cleanup at the end of the backward pass (before
                // we mark the future as completed).
                DistEngine::getInstance().cleanupBackwardPass(autogradContext);
              } catch (std::exception& e) {
                callbackFuture->setErrorIfNeeded(std::current_exception());
                return;
              }

              // Finally mark the 'uber' future as completed.
              if (!rpcFuture.hasError()) {
                callbackFuture->markCompleted(c10::IValue());
              } else {
                callbackFuture->setError(rpcFuture.exception_ptr());
              }
            });
          } catch (std::exception& e) {
            callbackFuture->setErrorIfNeeded(std::current_exception());
          }
        });

    // Return the future which waits for all async processing to be done.
    return callbackFuture;
  } else {
    lock.unlock();
    auto graphTask = autogradContext->retrieveGraphTask();
    at::launch([this, graphTask, sendFunction]() {
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
          /*incrementOutstandingTasks*/ false);
    });
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted(c10::IValue());
    return fut;
  }
}

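// Added note (descriptive comment): this is the entry point for the node that
// initiates the distributed backward pass; on the Python side it is reached
// through torch.distributed.autograd.backward(context_id, roots, retain_graph)
// (the Python binding itself lives outside this file). Nodes that merely
// participate in the pass enter through executeSendFunctionAsync() above.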
void DistEngine::execute(
    int64_t contextId,
    const variable_list& roots,
    bool retainGraph) {
  // Retrieve the context for the given context_id. This will throw if the
  // context_id is invalid.
  auto autogradContext =
      DistAutogradContainer::getInstance().retrieveContext(contextId);

  // Perform initial pre-processing.
  edge_list rootEdges;
  variable_list grads;
  validateRootsAndRetrieveEdges(roots, rootEdges, grads);

  std::shared_ptr<Node> graphRoot =
      std::make_shared<GraphRoot>(rootEdges, grads);
  edge_list outputEdges;
  // Compute dependencies locally, starting from all roots and all 'send'
  // functions.
  {
    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
    // Context should not have been initialized already.
    TORCH_INTERNAL_ASSERT(
        initializedContextIds_.find(autogradContext->contextId()) ==
        initializedContextIds_.end());

    computeDependencies(
        autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized.
    initializedContextIds_.insert(autogradContext->contextId());
  }

  BackwardPassCleanupGuard guard(autogradContext);

  // This needs to be blocking and as a result we wait for the future to
  // complete.
  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
      ->waitAndThrow();

  // Wait for all of the outstanding rpcs to complete.
  autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}

void DistEngine::cleanupBackwardPass(const ContextPtr& autogradContext) {
  // Validate that only the GraphTask is holding a reference to the Future
  // which holds gradients for the backward pass. This ensures that
  // after 'resetGraphTask' is called below, there are no remaining
  // references left to the gradients for the backward pass.
  //
  // This ensures our 'use_count' checks in
  // AccumulateGrad::accumulateGrad are correct and we're
  // not leaking any references to the gradients anywhere else.
  const auto& futureGrads =
      autogradContext->retrieveGraphTask()->future_result_;
  TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1);

  // Reset the graph task once we're done with all processing.
  autogradContext->resetGraphTask();

  // Clear any outstanding rpcs.
  autogradContext->clearOutstandingRpcs();

  // Clear the context id once we're done with the autograd engine
  // processing.
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  initializedContextIds_.erase(autogradContext->contextId());
}

size_t DistEngine::numBackwardPasses() const {
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  return initializedContextIds_.size();
}

std::unordered_map<std::string, int> DistEngine::getDebugInfo() const {
  std::unordered_map<std::string, int> debugInfo;
  debugInfo[kNumBackwardPasses] = numBackwardPasses();
  debugInfo[kNumAutogradContexts] =
      DistAutogradContainer::getInstance().numAutogradContexts();
  return debugInfo;
}
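
// Added note (descriptive comment): these counters are surfaced to Python via
// the debug-info API of torch.distributed.autograd (e.g. _get_debug_info()),
// reporting the number of backward passes currently in flight and the number
// of live autograd contexts on this worker.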

} // namespace autograd
} // namespace distributed
} // namespace torch