1 #pragma once 2 3 #include <torch/csrc/autograd/function.h> 4 5 namespace torch { 6 namespace distributed { 7 namespace autograd { 8 9 // As part of our distributed autograd implementation, whenever we send an RPC 10 // from one node to another, we add a 'SendRpcBackward' autograd function to the 11 // autograd graph. This is more or less a placeholder function that is used to 12 // kickoff the autograd engine on the current worker on the backward pass. The 13 // edges for this autograd function are the inputs to the RPC method. 14 // 15 // During the backward pass, this function is queued for execution in the 16 // autograd engine which eventually runs the rest of the autograd graph. 17 struct TORCH_API SendRpcBackward : public torch::autograd::Node { 18 public: 19 torch::autograd::variable_list apply( 20 torch::autograd::variable_list&& inputs) override; 21 22 // SendRpcBackward is actually the root of an autograd graph on the local 23 // node. As a result, it doesn't receive any 'inputs', but rather the RPC 24 // framework passes gradients over to this function to kickoff local autograd 25 // computation. 26 void setGrads(const torch::autograd::variable_list& grads); 27 28 // Retrieve the grads for the function. 29 const torch::autograd::variable_list& getGrads() const; 30 31 private: 32 torch::autograd::variable_list grads_; 33 }; 34 35 } // namespace autograd 36 } // namespace distributed 37 } // namespace torch 38