#pragma once

#include <torch/csrc/autograd/function.h>

namespace torch {
namespace distributed {
namespace autograd {

// As part of our distributed autograd implementation, whenever we send an
// RPC from one node to another, we add a 'SendRpcBackward' autograd function
// to the autograd graph. This is more or less a placeholder function that is
// used to kick off the autograd engine on the current worker during the
// backward pass. The edges for this autograd function are the inputs to the
// RPC method.
//
// During the backward pass, this function is queued for execution in the
// autograd engine, which eventually runs the rest of the autograd graph.
struct TORCH_API SendRpcBackward : public torch::autograd::Node {
 public:
  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& inputs) override;

  // SendRpcBackward is actually the root of an autograd graph on the local
  // node. As a result, it doesn't receive any 'inputs', but rather the RPC
  // framework passes gradients over to this function to kick off local
  // autograd computation.
  void setGrads(const torch::autograd::variable_list& grads);

  // Retrieve the grads set for this function.
  const torch::autograd::variable_list& getGrads() const;

 private:
  torch::autograd::variable_list grads_;
};
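
// A minimal usage sketch (kept as comments here, since the real wiring lives
// in the distributed autograd engine): during the backward pass, the
// gradients received over RPC are handed to this node via setGrads(), and the
// node is then enqueued as a root of the local autograd engine. Its apply()
// simply returns the stored grads so they flow along the node's edges into
// the rest of the local graph. The `incomingGrads` name below is hypothetical
// and only for illustration.
//
//   auto sendFunction = std::make_shared<SendRpcBackward>();
//   // Edges to the local graph are attached when the RPC inputs are recorded.
//   sendFunction->setGrads(incomingGrads);
//   // Running the engine with sendFunction as a root invokes apply(), which
//   // forwards the stored grads to the downstream autograd functions.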

} // namespace autograd
} // namespace distributed
} // namespace torch